mirror of https://github.com/smithy-lang/smithy-rs
Add incomplete Event Stream support with working Amazon Transcribe example (#653)
* Add incomplete Event Stream support with working Amazon Transcribe example * Make the raw response in SdkError generic * Fix XmlBindingTraitSerializerGeneratorTest * Make the build aware of the SMITHYRS_EXPERIMENTAL_EVENTSTREAM switch * Fix SigV4SigningCustomizationTest * Update changelog * Fix build when SMITHYRS_EXPERIMENTAL_EVENTSTREAM is not set * Add initial unit test for EventStreamUnmarshallerGenerator * Add event header unmarshalling support * Don't pull in event stream dependencies by default * Only add event stream signer to config for services that need it * Move event stream inlineables into smithy-eventstream * Fix some clippy lints * Transform event stream unions * Fix crash in SigV4SigningDecorator * Add test for unmarshalling errors * Incorporate CR feedback
This commit is contained in:
parent
b119782a65
commit
3b8f69c18d
|
@ -3,6 +3,7 @@ vNext (Month Day, Year)
|
|||
|
||||
**New this week**
|
||||
|
||||
- (When complete) Add Event Stream support (#653, #xyz)
|
||||
- (When complete) Add profile file provider for region (#594, #xyz)
|
||||
|
||||
v0.21 (August 19th, 2021)
|
||||
|
|
|
@ -7,12 +7,17 @@ license = "Apache-2.0"
|
|||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[features]
|
||||
sign-eventstream = ["smithy-eventstream", "aws-sigv4/sign-eventstream"]
|
||||
default = []
|
||||
|
||||
[dependencies]
|
||||
http = "0.2.2"
|
||||
aws-sigv4 = { path = "../aws-sigv4" }
|
||||
aws-auth = { path = "../aws-auth" }
|
||||
aws-types = { path = "../aws-types" }
|
||||
smithy-http = { path = "../../../rust-runtime/smithy-http" }
|
||||
smithy-eventstream = { path = "../../../rust-runtime/smithy-eventstream", optional = true }
|
||||
# Trying this out as an experiment. thiserror can be removed and replaced with hand written error
|
||||
# implementations and it is not a breaking change.
|
||||
thiserror = "1"
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0.
|
||||
*/
|
||||
|
||||
use crate::middleware::Signature;
|
||||
use aws_auth::Credentials;
|
||||
use aws_sigv4::event_stream::sign_message;
|
||||
use aws_sigv4::SigningParams;
|
||||
use aws_types::region::SigningRegion;
|
||||
use aws_types::SigningService;
|
||||
use smithy_eventstream::frame::{Message, SignMessage, SignMessageError};
|
||||
use smithy_http::property_bag::PropertyBag;
|
||||
use std::sync::{Arc, Mutex, MutexGuard};
|
||||
use std::time::SystemTime;
|
||||
|
||||
/// Event Stream SigV4 signing implementation.
|
||||
#[derive(Debug)]
|
||||
pub struct SigV4Signer {
|
||||
properties: Arc<Mutex<PropertyBag>>,
|
||||
last_signature: Option<String>,
|
||||
}
|
||||
|
||||
impl SigV4Signer {
|
||||
pub fn new(properties: Arc<Mutex<PropertyBag>>) -> Self {
|
||||
Self {
|
||||
properties,
|
||||
last_signature: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SignMessage for SigV4Signer {
|
||||
fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
|
||||
let properties = PropertyAccessor(self.properties.lock().unwrap());
|
||||
if self.last_signature.is_none() {
|
||||
// The Signature property should exist in the property bag for all Event Stream requests.
|
||||
self.last_signature = Some(properties.expect::<Signature>().as_ref().into())
|
||||
}
|
||||
|
||||
// Every single one of these values would have been retrieved during the initial request,
|
||||
// so we can safely assume they all exist in the property bag at this point.
|
||||
let credentials = properties.expect::<Credentials>();
|
||||
let region = properties.expect::<SigningRegion>();
|
||||
let signing_service = properties.expect::<SigningService>();
|
||||
let time = properties
|
||||
.get::<SystemTime>()
|
||||
.copied()
|
||||
.unwrap_or_else(SystemTime::now);
|
||||
let params = SigningParams {
|
||||
access_key: credentials.access_key_id(),
|
||||
secret_key: credentials.secret_access_key(),
|
||||
security_token: credentials.session_token(),
|
||||
region: region.as_ref(),
|
||||
service_name: signing_service.as_ref(),
|
||||
date_time: time.into(),
|
||||
settings: (),
|
||||
};
|
||||
|
||||
let (signed_message, signature) =
|
||||
sign_message(&message, self.last_signature.as_ref().unwrap(), ¶ms).into_parts();
|
||||
self.last_signature = Some(signature);
|
||||
|
||||
Ok(signed_message)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(EventStream): Make a new type around `Arc<Mutex<PropertyBag>>` called `SharedPropertyBag`
|
||||
// and abstract the mutex away entirely.
|
||||
struct PropertyAccessor<'a>(MutexGuard<'a, PropertyBag>);
|
||||
|
||||
impl<'a> PropertyAccessor<'a> {
|
||||
fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
|
||||
self.0.get::<T>()
|
||||
}
|
||||
|
||||
fn expect<T: Send + Sync + 'static>(&self) -> &T {
|
||||
self.get::<T>()
|
||||
.expect("property should have been inserted into property bag via middleware")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::event_stream::SigV4Signer;
|
||||
use crate::middleware::Signature;
|
||||
use aws_auth::Credentials;
|
||||
use aws_types::region::Region;
|
||||
use aws_types::region::SigningRegion;
|
||||
use aws_types::SigningService;
|
||||
use smithy_eventstream::frame::{HeaderValue, Message, SignMessage};
|
||||
use smithy_http::property_bag::PropertyBag;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, UNIX_EPOCH};
|
||||
|
||||
#[test]
|
||||
fn sign_message() {
|
||||
let region = Region::new("us-east-1");
|
||||
let mut properties = PropertyBag::new();
|
||||
properties.insert(region.clone());
|
||||
properties.insert(UNIX_EPOCH + Duration::new(1611160427, 0));
|
||||
properties.insert(SigningService::from_static("transcribe"));
|
||||
properties.insert(Credentials::from_keys("AKIAfoo", "bar", None));
|
||||
properties.insert(SigningRegion::from(region));
|
||||
properties.insert(Signature::new("initial-signature".into()));
|
||||
|
||||
let mut signer = SigV4Signer::new(Arc::new(Mutex::new(properties)));
|
||||
let mut signatures = Vec::new();
|
||||
for _ in 0..5 {
|
||||
let signed = signer
|
||||
.sign(Message::new(&b"identical message"[..]))
|
||||
.unwrap();
|
||||
if let HeaderValue::ByteArray(signature) = signed
|
||||
.headers()
|
||||
.iter()
|
||||
.find(|h| h.name().as_str() == ":chunk-signature")
|
||||
.unwrap()
|
||||
.value()
|
||||
{
|
||||
signatures.push(signature.clone());
|
||||
} else {
|
||||
panic!("failed to get the :chunk-signature")
|
||||
}
|
||||
}
|
||||
for i in 1..signatures.len() {
|
||||
assert_ne!(signatures[i - 1], signatures[i]);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -7,5 +7,8 @@
|
|||
//!
|
||||
//! In the future, additional signature algorithms can be enabled as Cargo Features.
|
||||
|
||||
#[cfg(feature = "sign-eventstream")]
|
||||
pub mod event_stream;
|
||||
|
||||
pub mod middleware;
|
||||
pub mod signer;
|
||||
|
|
|
@ -24,8 +24,10 @@ impl Signature {
|
|||
pub fn new(signature: String) -> Self {
|
||||
Self(signature)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_str(&self) -> &str {
|
||||
impl AsRef<str> for Signature {
|
||||
fn as_ref(&self) -> &str {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,7 +10,7 @@ description = "AWS SigV4 signer"
|
|||
[features]
|
||||
sign-http = ["http", "http-body", "percent-encoding", "form_urlencoded"]
|
||||
sign-eventstream = ["smithy-eventstream"]
|
||||
default = ["sign-http", "sign-eventstream"]
|
||||
default = ["sign-http"]
|
||||
|
||||
[dependencies]
|
||||
chrono = { version = "0.4", default-features = false, features = ["clock", "std"] }
|
||||
|
|
|
@ -32,5 +32,5 @@ object AwsRuntimeType {
|
|||
val S3Errors by lazy { RuntimeType.forInlineDependency(InlineAwsDependency.forRustFile("s3_errors")) }
|
||||
}
|
||||
|
||||
fun RuntimeConfig.awsRuntimeDependency(name: String, features: List<String> = listOf()): CargoDependency =
|
||||
fun RuntimeConfig.awsRuntimeDependency(name: String, features: Set<String> = setOf()): CargoDependency =
|
||||
CargoDependency(name, awsRoot().crateLocation(), features = features)
|
||||
|
|
|
@ -49,7 +49,8 @@ class IntegrationTestDependencies(
|
|||
override fun section(section: LibRsSection) = when (section) {
|
||||
LibRsSection.Body -> writable {
|
||||
if (hasTests) {
|
||||
val smithyClient = CargoDependency.SmithyClient(runtimeConfig).copy(features = listOf("test-util"), scope = DependencyScope.Dev)
|
||||
val smithyClient = CargoDependency.SmithyClient(runtimeConfig)
|
||||
.copy(features = setOf("test-util"), scope = DependencyScope.Dev)
|
||||
addDependency(smithyClient)
|
||||
addDependency(SerdeJson)
|
||||
addDependency(Tokio)
|
||||
|
@ -63,5 +64,5 @@ class IntegrationTestDependencies(
|
|||
}
|
||||
|
||||
val Criterion = CargoDependency("criterion", CratesIo("0.3"), scope = DependencyScope.Dev)
|
||||
val SerdeJson = CargoDependency("serde_json", CratesIo("1"), features = emptyList(), scope = DependencyScope.Dev)
|
||||
val Tokio = CargoDependency("tokio", CratesIo("1"), features = listOf("macros", "test-util"), scope = DependencyScope.Dev)
|
||||
val SerdeJson = CargoDependency("serde_json", CratesIo("1"), features = emptySet(), scope = DependencyScope.Dev)
|
||||
val Tokio = CargoDependency("tokio", CratesIo("1"), features = setOf("macros", "test-util"), scope = DependencyScope.Dev)
|
||||
|
|
|
@ -13,12 +13,14 @@ import software.amazon.smithy.model.shapes.OperationShape
|
|||
import software.amazon.smithy.model.shapes.ServiceShape
|
||||
import software.amazon.smithy.model.shapes.ShapeId
|
||||
import software.amazon.smithy.model.traits.OptionalAuthTrait
|
||||
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
|
||||
import software.amazon.smithy.rust.codegen.rustlang.Writable
|
||||
import software.amazon.smithy.rust.codegen.rustlang.asType
|
||||
import software.amazon.smithy.rust.codegen.rustlang.rust
|
||||
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
|
||||
import software.amazon.smithy.rust.codegen.rustlang.writable
|
||||
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
|
||||
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
|
||||
import software.amazon.smithy.rust.codegen.smithy.customize.OperationCustomization
|
||||
import software.amazon.smithy.rust.codegen.smithy.customize.OperationSection
|
||||
import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator
|
||||
|
@ -28,11 +30,14 @@ import software.amazon.smithy.rust.codegen.smithy.generators.config.ServiceConfi
|
|||
import software.amazon.smithy.rust.codegen.smithy.letIf
|
||||
import software.amazon.smithy.rust.codegen.util.dq
|
||||
import software.amazon.smithy.rust.codegen.util.expectTrait
|
||||
import software.amazon.smithy.rust.codegen.util.hasEventStreamOperations
|
||||
import software.amazon.smithy.rust.codegen.util.hasTrait
|
||||
import software.amazon.smithy.rust.codegen.util.isInputEventStream
|
||||
|
||||
/**
|
||||
* The SigV4SigningDecorator:
|
||||
* - adds a `signing_service()` method to `config` to return the default signing service
|
||||
* - adds a `new_event_stream_signer()` method to `config` to create an Event Stream SigV4 signer
|
||||
* - sets the `SigningService` during operation construction
|
||||
* - sets a default `OperationSigningConfig` A future enhancement will customize this for specific services that need
|
||||
* different behavior.
|
||||
|
@ -47,8 +52,12 @@ class SigV4SigningDecorator : RustCodegenDecorator {
|
|||
protocolConfig: ProtocolConfig,
|
||||
baseCustomizations: List<ConfigCustomization>
|
||||
): List<ConfigCustomization> {
|
||||
return baseCustomizations.letIf(applies(protocolConfig)) {
|
||||
it + SigV4SigningConfig(protocolConfig.serviceShape.expectTrait())
|
||||
return baseCustomizations.letIf(applies(protocolConfig)) { customizations ->
|
||||
customizations + SigV4SigningConfig(
|
||||
protocolConfig.runtimeConfig,
|
||||
protocolConfig.serviceShape.hasEventStreamOperations(protocolConfig.model),
|
||||
protocolConfig.serviceShape.expectTrait()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -58,26 +67,63 @@ class SigV4SigningDecorator : RustCodegenDecorator {
|
|||
baseCustomizations: List<OperationCustomization>
|
||||
): List<OperationCustomization> {
|
||||
return baseCustomizations.letIf(applies(protocolConfig)) {
|
||||
it + SigV4SigningFeature(operation, protocolConfig.runtimeConfig, protocolConfig.serviceShape, protocolConfig.model)
|
||||
it + SigV4SigningFeature(
|
||||
protocolConfig.model,
|
||||
operation,
|
||||
protocolConfig.runtimeConfig,
|
||||
protocolConfig.serviceShape,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class SigV4SigningConfig(private val sigV4Trait: SigV4Trait) : ConfigCustomization() {
|
||||
class SigV4SigningConfig(
|
||||
runtimeConfig: RuntimeConfig,
|
||||
private val serviceHasEventStream: Boolean,
|
||||
private val sigV4Trait: SigV4Trait
|
||||
) : ConfigCustomization() {
|
||||
private val codegenScope = arrayOf(
|
||||
"SigV4Signer" to RuntimeType(
|
||||
"SigV4Signer",
|
||||
runtimeConfig.awsRuntimeDependency("aws-sig-auth", setOf("sign-eventstream")),
|
||||
"aws_sig_auth::event_stream"
|
||||
),
|
||||
"PropertyBag" to RuntimeType(
|
||||
"PropertyBag",
|
||||
CargoDependency.SmithyHttp(runtimeConfig),
|
||||
"smithy_http::property_bag"
|
||||
)
|
||||
)
|
||||
|
||||
override fun section(section: ServiceConfig): Writable {
|
||||
return when (section) {
|
||||
is ServiceConfig.ConfigImpl -> writable {
|
||||
rust(
|
||||
rustTemplate(
|
||||
"""
|
||||
/// The signature version 4 service signing name to use in the credential scope when signing requests.
|
||||
///
|
||||
/// The signing service may be overidden by the `Endpoint`, or by specifying a custom [`SigningService`](aws_types::SigningService) during
|
||||
/// operation construction
|
||||
/// The signing service may be overridden by the `Endpoint`, or by specifying a custom
|
||||
/// [`SigningService`](aws_types::SigningService) during operation construction
|
||||
pub fn signing_service(&self) -> &'static str {
|
||||
${sigV4Trait.name.dq()}
|
||||
}
|
||||
"""
|
||||
""",
|
||||
*codegenScope
|
||||
)
|
||||
if (serviceHasEventStream) {
|
||||
rustTemplate(
|
||||
"""
|
||||
/// Creates a new Event Stream `SignMessage` implementor.
|
||||
pub fn new_event_stream_signer(
|
||||
&self,
|
||||
properties: std::sync::Arc<std::sync::Mutex<#{PropertyBag}>>
|
||||
) -> #{SigV4Signer} {
|
||||
#{SigV4Signer}::new(properties)
|
||||
}
|
||||
""",
|
||||
*codegenScope
|
||||
)
|
||||
}
|
||||
}
|
||||
else -> emptySection
|
||||
}
|
||||
|
@ -95,10 +141,10 @@ fun disableDoubleEncode(service: ServiceShape) = when {
|
|||
}
|
||||
|
||||
class SigV4SigningFeature(
|
||||
private val model: Model,
|
||||
private val operation: OperationShape,
|
||||
runtimeConfig: RuntimeConfig,
|
||||
private val service: ServiceShape,
|
||||
model: Model
|
||||
) :
|
||||
OperationCustomization() {
|
||||
private val codegenScope =
|
||||
|
@ -111,9 +157,9 @@ class SigV4SigningFeature(
|
|||
is OperationSection.MutateRequest -> writable {
|
||||
rustTemplate(
|
||||
"""
|
||||
##[allow(unused_mut)]
|
||||
let mut signing_config = #{sig_auth}::signer::OperationSigningConfig::default_config();
|
||||
""",
|
||||
##[allow(unused_mut)]
|
||||
let mut signing_config = #{sig_auth}::signer::OperationSigningConfig::default_config();
|
||||
""",
|
||||
*codegenScope
|
||||
)
|
||||
if (needsAmzSha256(service)) {
|
||||
|
@ -128,6 +174,12 @@ class SigV4SigningFeature(
|
|||
"${section.request}.properties_mut().insert(#{sig_auth}::signer::SignableBody::UnsignedPayload);",
|
||||
*codegenScope
|
||||
)
|
||||
} else if (operation.isInputEventStream(model)) {
|
||||
// TODO(EventStream): Is this actually correct for all Event Stream operations?
|
||||
rustTemplate(
|
||||
"${section.request}.properties_mut().insert(#{sig_auth}::signer::SignableBody::Bytes(&[]));",
|
||||
*codegenScope
|
||||
)
|
||||
}
|
||||
// some operations are either unsigned or optionally signed:
|
||||
val authSchemes = serviceIndex.getEffectiveAuthSchemes(service, operation)
|
||||
|
@ -140,9 +192,9 @@ class SigV4SigningFeature(
|
|||
}
|
||||
rustTemplate(
|
||||
"""
|
||||
${section.request}.properties_mut().insert(signing_config);
|
||||
${section.request}.properties_mut().insert(#{aws_types}::SigningService::from_static(${section.config}.signing_service()));
|
||||
""",
|
||||
${section.request}.properties_mut().insert(signing_config);
|
||||
${section.request}.properties_mut().insert(#{aws_types}::SigningService::from_static(${section.config}.signing_service()));
|
||||
""",
|
||||
*codegenScope
|
||||
)
|
||||
}
|
||||
|
|
|
@ -14,13 +14,19 @@ import software.amazon.smithy.rust.codegen.testutil.unitTest
|
|||
internal class SigV4SigningCustomizationTest {
|
||||
@Test
|
||||
fun `generates a valid config`() {
|
||||
val project = stubConfigProject(SigV4SigningConfig(SigV4Trait.builder().name("test-service").build()))
|
||||
val project = stubConfigProject(
|
||||
SigV4SigningConfig(
|
||||
AwsTestRuntimeConfig,
|
||||
true,
|
||||
SigV4Trait.builder().name("test-service").build()
|
||||
)
|
||||
)
|
||||
project.lib {
|
||||
it.unitTest(
|
||||
"""
|
||||
let conf = crate::config::Config::builder().build();
|
||||
assert_eq!(conf.signing_service(), "test-service");
|
||||
"""
|
||||
let conf = crate::config::Config::builder().build();
|
||||
assert_eq!(conf.signing_service(), "test-service");
|
||||
"""
|
||||
)
|
||||
}
|
||||
project.compileAndTest()
|
||||
|
|
|
@ -235,6 +235,8 @@ task("generateSmithyBuild") {
|
|||
projectDir.resolve("smithy-build.json").writeText(generateSmithyBuild(awsServices))
|
||||
}
|
||||
inputs.property("servicelist", awsServices.sortedBy { it.module }.toString())
|
||||
// TODO(EventStream): Remove this when removing SMITHYRS_EXPERIMENTAL_EVENTSTREAM
|
||||
inputs.property("_eventStreamCacheInvalidation", System.getenv("SMITHYRS_EXPERIMENTAL_EVENTSTREAM") ?: "0")
|
||||
inputs.dir(projectDir.resolve("aws-models"))
|
||||
outputs.file(projectDir.resolve("smithy-build.json"))
|
||||
}
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
[package]
|
||||
name = "transcribestreaming"
|
||||
version = "0.1.0"
|
||||
authors = ["John DiSanti <jdisanti@amazon.com>"]
|
||||
edition = "2018"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
aws-auth-providers = { path = "../../build/aws-sdk/aws-auth-providers" }
|
||||
aws-sdk-transcribestreaming = { package = "aws-sdk-transcribestreaming", path = "../../build/aws-sdk/transcribestreaming" }
|
||||
aws-types = { path = "../../build/aws-sdk/aws-types" }
|
||||
|
||||
async-stream = "0.3"
|
||||
bytes = "1"
|
||||
hound = "3.4"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
tracing-subscriber = "0.2.18"
|
Binary file not shown.
|
@ -0,0 +1,66 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0.
|
||||
*/
|
||||
|
||||
use async_stream::stream;
|
||||
use aws_sdk_transcribestreaming::model::{AudioEvent, AudioStream, LanguageCode, MediaEncoding};
|
||||
use aws_sdk_transcribestreaming::{Blob, Client, Config, Region};
|
||||
use bytes::BufMut;
|
||||
use std::time::Duration;
|
||||
|
||||
const CHUNK_SIZE: usize = 8192;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let input_stream = stream! {
|
||||
let pcm = pcm_data();
|
||||
for chunk in pcm.chunks(CHUNK_SIZE) {
|
||||
// Sleeping isn't necessary, but emphasizes the streaming aspect of this
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build()));
|
||||
}
|
||||
// Must send an empty chunk at the end
|
||||
yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(Vec::new())).build()));
|
||||
};
|
||||
|
||||
let config = Config::builder()
|
||||
.region(Region::from_static("us-west-2"))
|
||||
.build();
|
||||
let client = Client::from_conf(config);
|
||||
|
||||
let mut output = client
|
||||
.start_stream_transcription()
|
||||
.language_code(LanguageCode::EnGb)
|
||||
.media_sample_rate_hertz(8000)
|
||||
.media_encoding(MediaEncoding::Pcm)
|
||||
.audio_stream(input_stream.into())
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
loop {
|
||||
match output.transcript_result_stream.recv().await {
|
||||
Ok(Some(transcription)) => {
|
||||
println!("Received transcription response:\n{:?}\n", transcription)
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(err) => println!("Received an error: {:?}", err),
|
||||
}
|
||||
}
|
||||
println!("Done.")
|
||||
}
|
||||
|
||||
fn pcm_data() -> Vec<u8> {
|
||||
let audio = include_bytes!("../audio/hello-transcribe-8000.wav");
|
||||
let reader = hound::WavReader::new(&audio[..]).unwrap();
|
||||
let samples_result: hound::Result<Vec<i16>> = reader.into_samples::<i16>().collect();
|
||||
|
||||
let mut pcm: Vec<u8> = Vec::new();
|
||||
for sample in samples_result.unwrap() {
|
||||
pcm.put_i16_le(sample);
|
||||
}
|
||||
pcm
|
||||
}
|
|
@ -106,6 +106,8 @@ task("generateSmithyBuild") {
|
|||
doFirst {
|
||||
projectDir.resolve("smithy-build.json").writeText(generateSmithyBuild(CodegenTests))
|
||||
}
|
||||
// TODO(EventStream): Remove this when removing SMITHYRS_EXPERIMENTAL_EVENTSTREAM
|
||||
inputs.property("_eventStreamCacheInvalidation", System.getenv("SMITHYRS_EXPERIMENTAL_EVENTSTREAM") ?: "0")
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -88,6 +88,9 @@ class InlineDependency(
|
|||
private fun forRustFile(name: String, vararg additionalDependencies: RustDependency) =
|
||||
forRustFile(name, "inlineable", *additionalDependencies)
|
||||
|
||||
fun eventStream(runtimeConfig: RuntimeConfig) =
|
||||
forRustFile("event_stream", CargoDependency.SmithyEventStream(runtimeConfig))
|
||||
|
||||
fun jsonErrors(runtimeConfig: RuntimeConfig) =
|
||||
forRustFile("json_errors", CargoDependency.Http, CargoDependency.SmithyTypes(runtimeConfig))
|
||||
|
||||
|
@ -118,8 +121,15 @@ data class CargoDependency(
|
|||
private val location: DependencyLocation,
|
||||
val scope: DependencyScope = DependencyScope.Compile,
|
||||
val optional: Boolean = false,
|
||||
private val features: List<String> = listOf()
|
||||
val features: Set<String> = emptySet()
|
||||
) : RustDependency(name) {
|
||||
val key: Triple<String, DependencyLocation, DependencyScope> get() = Triple(name, location, scope)
|
||||
|
||||
fun canMergeWith(other: CargoDependency): Boolean = key == other.key
|
||||
|
||||
fun withFeature(feature: String): CargoDependency {
|
||||
return copy(features = features.toMutableSet().apply { add(feature) })
|
||||
}
|
||||
|
||||
override fun version(): String = when (location) {
|
||||
is CratesIo -> location.version
|
||||
|
@ -173,12 +183,15 @@ data class CargoDependency(
|
|||
val Md5 = CargoDependency("md5", CratesIo("0.7"))
|
||||
val FastRand = CargoDependency("fastrand", CratesIo("1"))
|
||||
val Http: CargoDependency = CargoDependency("http", CratesIo("0.2"))
|
||||
val Hyper: CargoDependency = CargoDependency("hyper", CratesIo("0.14"))
|
||||
val HyperWithStream: CargoDependency = Hyper.withFeature("stream")
|
||||
val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4"), optional = true)
|
||||
fun SmithyTypes(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("types")
|
||||
|
||||
fun SmithyClient(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("client")
|
||||
fun SmithyEventStream(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("eventstream")
|
||||
fun SmithyHttp(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("http")
|
||||
fun SmithyHttpTower(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("http-tower")
|
||||
fun SmithyClient(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("client")
|
||||
|
||||
fun ProtocolTestHelpers(runtimeConfig: RuntimeConfig) = CargoDependency(
|
||||
"protocol-test-helpers", runtimeConfig.runtimeCrateLocation.crateLocation(), scope = DependencyScope.Dev
|
||||
|
|
|
@ -101,8 +101,10 @@ fun CodegenWriterDelegator<RustWriter>.finalize(
|
|||
this.useFileWriter("src/lib.rs", "crate::lib") { writer ->
|
||||
LibRsGenerator(settings.moduleDescription, modules, libRsCustomizations).render(writer)
|
||||
}
|
||||
val cargoDependencies =
|
||||
this.dependencies.map { RustDependency.fromSymbolDependency(it) }.filterIsInstance<CargoDependency>().distinct()
|
||||
val cargoDependencies = mergeDependencyFeatures(
|
||||
this.dependencies.map { RustDependency.fromSymbolDependency(it) }
|
||||
.filterIsInstance<CargoDependency>().distinct()
|
||||
)
|
||||
this.useFileWriter("Cargo.toml") {
|
||||
val cargoToml = CargoTomlGenerator(
|
||||
settings,
|
||||
|
@ -114,3 +116,18 @@ fun CodegenWriterDelegator<RustWriter>.finalize(
|
|||
}
|
||||
flushWriters()
|
||||
}
|
||||
|
||||
private fun CargoDependency.mergeWith(other: CargoDependency): CargoDependency {
|
||||
check(key == other.key)
|
||||
return copy(
|
||||
features = features + other.features,
|
||||
optional = optional && other.optional
|
||||
)
|
||||
}
|
||||
|
||||
internal fun mergeDependencyFeatures(cargoDependencies: List<CargoDependency>): List<CargoDependency> =
|
||||
cargoDependencies.groupBy { it.key }
|
||||
.mapValues { group -> group.value.reduce { acc, next -> acc.mergeWith(next) } }
|
||||
.values
|
||||
.toList()
|
||||
.sortedBy { it.name }
|
||||
|
|
|
@ -28,7 +28,10 @@ import software.amazon.smithy.rust.codegen.smithy.generators.implBlock
|
|||
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolLoader
|
||||
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
|
||||
import software.amazon.smithy.rust.codegen.smithy.transformers.AddErrorMessage
|
||||
import software.amazon.smithy.rust.codegen.smithy.transformers.EventStreamNormalizer
|
||||
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
|
||||
import software.amazon.smithy.rust.codegen.smithy.transformers.RecursiveShapeBoxer
|
||||
import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStreamOperations
|
||||
import software.amazon.smithy.rust.codegen.util.CommandFailed
|
||||
import software.amazon.smithy.rust.codegen.util.getTrait
|
||||
import software.amazon.smithy.rust.codegen.util.hasTrait
|
||||
|
@ -78,6 +81,9 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC
|
|||
private fun baselineTransform(model: Model) =
|
||||
model.let(RecursiveShapeBoxer::transform)
|
||||
.letIf(settings.codegenConfig.addMessageToErrors, AddErrorMessage::transform)
|
||||
.let(OperationNormalizer::transform)
|
||||
.let(RemoveEventStreamOperations::transform)
|
||||
.let(EventStreamNormalizer::transform)
|
||||
|
||||
fun execute() {
|
||||
logger.info("generating Rust client...")
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0.
|
||||
*/
|
||||
|
||||
package software.amazon.smithy.rust.codegen.smithy
|
||||
|
||||
import software.amazon.smithy.codegen.core.Symbol
|
||||
import software.amazon.smithy.model.Model
|
||||
import software.amazon.smithy.model.shapes.MemberShape
|
||||
import software.amazon.smithy.model.shapes.OperationShape
|
||||
import software.amazon.smithy.model.shapes.Shape
|
||||
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
|
||||
import software.amazon.smithy.rust.codegen.rustlang.RustType
|
||||
import software.amazon.smithy.rust.codegen.rustlang.render
|
||||
import software.amazon.smithy.rust.codegen.rustlang.stripOuter
|
||||
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
|
||||
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
|
||||
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticOutputTrait
|
||||
import software.amazon.smithy.rust.codegen.util.getTrait
|
||||
import software.amazon.smithy.rust.codegen.util.isEventStream
|
||||
import software.amazon.smithy.rust.codegen.util.isInputEventStream
|
||||
|
||||
/**
|
||||
* Wrapping symbol provider to wrap modeled types with the smithy-http Event Stream send/receive types.
|
||||
*/
|
||||
class EventStreamSymbolProvider(
|
||||
private val runtimeConfig: RuntimeConfig,
|
||||
base: RustSymbolProvider,
|
||||
private val model: Model
|
||||
) : WrappingSymbolProvider(base) {
|
||||
override fun toSymbol(shape: Shape): Symbol {
|
||||
val initial = super.toSymbol(shape)
|
||||
|
||||
// We only want to wrap with Event Stream types when dealing with member shapes
|
||||
if (shape is MemberShape && shape.isEventStream(model)) {
|
||||
// Determine if the member has a container that is a synthetic input or output
|
||||
val operationShape = model.expectShape(shape.container).let { maybeInputOutput ->
|
||||
val operationId = maybeInputOutput.getTrait<SyntheticInputTrait>()?.operation
|
||||
?: maybeInputOutput.getTrait<SyntheticOutputTrait>()?.operation
|
||||
operationId?.let { model.expectShape(it, OperationShape::class.java) }
|
||||
}
|
||||
// If we find an operation shape, then we can wrap the type
|
||||
if (operationShape != null) {
|
||||
val error = operationShape.errorSymbol(this).toSymbol()
|
||||
val errorFmt = error.rustType().render(fullyQualified = true)
|
||||
val innerFmt = initial.rustType().stripOuter<RustType.Option>().render(fullyQualified = true)
|
||||
val outer = when (shape.isInputEventStream(model)) {
|
||||
true -> "EventStreamInput<$innerFmt>"
|
||||
else -> "Receiver<$innerFmt, $errorFmt>"
|
||||
}
|
||||
val rustType = RustType.Opaque(outer, "smithy_http::event_stream")
|
||||
return initial.toBuilder()
|
||||
.name(rustType.name)
|
||||
.rustType(rustType)
|
||||
.addReference(error)
|
||||
.addReference(initial)
|
||||
.addDependency(CargoDependency.SmithyHttp(runtimeConfig).withFeature("event-stream"))
|
||||
.build()
|
||||
}
|
||||
}
|
||||
|
||||
return initial
|
||||
}
|
||||
}
|
|
@ -161,6 +161,15 @@ data class RuntimeType(val name: String?, val dependency: RustDependency?, val n
|
|||
val HttpRequestBuilder = Http("request::Builder")
|
||||
val HttpResponseBuilder = Http("response::Builder")
|
||||
|
||||
val Hyper = CargoDependency.Hyper.asType()
|
||||
|
||||
fun eventStreamReceiver(runtimeConfig: RuntimeConfig): RuntimeType =
|
||||
RuntimeType(
|
||||
"Receiver",
|
||||
dependency = CargoDependency.SmithyHttp(runtimeConfig),
|
||||
"smithy_http::event_stream"
|
||||
)
|
||||
|
||||
fun jsonErrors(runtimeConfig: RuntimeConfig) =
|
||||
forInlineDependency(InlineDependency.jsonErrors(runtimeConfig))
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@ class RustCodegenPlugin : SmithyBuildPlugin {
|
|||
companion object {
|
||||
fun baseSymbolProvider(model: Model, serviceShape: ServiceShape, symbolVisitorConfig: SymbolVisitorConfig = DefaultConfig) =
|
||||
SymbolVisitor(model, serviceShape = serviceShape, config = symbolVisitorConfig)
|
||||
.let { EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model) }
|
||||
.let { StreamingShapeSymbolProvider(it, model) }
|
||||
.let { BaseSymbolMetadataProvider(it) }
|
||||
.let { StreamingShapeMetadataProvider(it, model) }
|
||||
|
|
|
@ -27,7 +27,10 @@ val ClippyAllowLints = listOf(
|
|||
"should_implement_trait",
|
||||
|
||||
// protocol tests use silly names like `baz`, don't flag that
|
||||
"blacklisted_name"
|
||||
"blacklisted_name",
|
||||
|
||||
// Forcing use of `vec![]` can make codegen harder in some cases
|
||||
"vec_init_then_push",
|
||||
)
|
||||
|
||||
class AllowClippyLints : LibRsCustomization() {
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
package software.amazon.smithy.rust.codegen.smithy.generators
|
||||
|
||||
import software.amazon.smithy.codegen.core.Symbol
|
||||
import software.amazon.smithy.model.Model
|
||||
import software.amazon.smithy.model.shapes.MemberShape
|
||||
import software.amazon.smithy.model.shapes.StructureShape
|
||||
|
@ -51,13 +52,14 @@ class OperationBuildError(private val runtimeConfig: RuntimeConfig) {
|
|||
fun MemberShape.setterName(): String = "set_${this.memberName.toSnakeCase()}"
|
||||
|
||||
class BuilderGenerator(
|
||||
val model: Model,
|
||||
private val model: Model,
|
||||
private val symbolProvider: RustSymbolProvider,
|
||||
private val shape: StructureShape
|
||||
) {
|
||||
private val members: List<MemberShape> = shape.allMembers.values.toList()
|
||||
private val runtimeConfig = symbolProvider.config().runtimeConfig
|
||||
private val members: List<MemberShape> = shape.allMembers.values.toList()
|
||||
private val structureSymbol = symbolProvider.toSymbol(shape)
|
||||
|
||||
fun render(writer: RustWriter) {
|
||||
val symbol = symbolProvider.toSymbol(shape)
|
||||
// TODO: figure out exactly what docs we want on a the builder module
|
||||
|
@ -104,6 +106,53 @@ class BuilderGenerator(
|
|||
}
|
||||
}
|
||||
|
||||
// TODO(EventStream): [DX] Consider updating builders to take EventInputStream as Into<EventInputStream>
|
||||
private fun renderBuilderMember(writer: RustWriter, member: MemberShape, memberName: String, memberSymbol: Symbol) {
|
||||
// builder members are crate-public to enable using them
|
||||
// directly in serializers/deserializers
|
||||
writer.write("pub(crate) $memberName: #T,", memberSymbol)
|
||||
}
|
||||
|
||||
private fun renderBuilderMemberFn(
|
||||
writer: RustWriter,
|
||||
coreType: RustType,
|
||||
member: MemberShape,
|
||||
memberName: String,
|
||||
memberSymbol: Symbol
|
||||
) {
|
||||
fun builderConverter(coreType: RustType) = when (coreType) {
|
||||
is RustType.String,
|
||||
is RustType.Box -> "input.into()"
|
||||
else -> "input"
|
||||
}
|
||||
|
||||
val signature = when (coreType) {
|
||||
is RustType.String,
|
||||
is RustType.Box -> "(mut self, input: impl Into<${coreType.render(true)}>) -> Self"
|
||||
else -> "(mut self, input: ${coreType.render(true)}) -> Self"
|
||||
}
|
||||
writer.documentShape(member, model)
|
||||
writer.rustBlock("pub fn $memberName$signature") {
|
||||
write("self.$memberName = Some(${builderConverter(coreType)});")
|
||||
write("self")
|
||||
}
|
||||
}
|
||||
|
||||
private fun renderBuilderMemberSetterFn(
|
||||
writer: RustWriter,
|
||||
outerType: RustType,
|
||||
member: MemberShape,
|
||||
memberName: String,
|
||||
memberSymbol: Symbol
|
||||
) {
|
||||
// Render a `set_foo` method. This is useful as a target for code generation, because the argument type
|
||||
// is the same as the resulting member type, and is always optional.
|
||||
val inputType = outerType.asOptional()
|
||||
writer.rustBlock("pub fn ${member.setterName()}(mut self, input: ${inputType.render(true)}) -> Self") {
|
||||
rust("self.$memberName = input; self")
|
||||
}
|
||||
}
|
||||
|
||||
private fun renderBuilder(writer: RustWriter) {
|
||||
val builderName = "Builder"
|
||||
|
||||
|
@ -119,18 +168,10 @@ class BuilderGenerator(
|
|||
val memberName = symbolProvider.toMemberName(member)
|
||||
// All fields in the builder are optional
|
||||
val memberSymbol = symbolProvider.toSymbol(member).makeOptional()
|
||||
// builder members are crate-public to enable using them
|
||||
// directly in serializers/deserializers
|
||||
write("pub(crate) $memberName: #T,", memberSymbol)
|
||||
renderBuilderMember(this, member, memberName, memberSymbol)
|
||||
}
|
||||
}
|
||||
|
||||
fun builderConverter(coreType: RustType) = when (coreType) {
|
||||
is RustType.String,
|
||||
is RustType.Box -> "input.into()"
|
||||
else -> "input"
|
||||
}
|
||||
|
||||
writer.rustBlock("impl $builderName") {
|
||||
members.forEach { member ->
|
||||
// All fields in the builder are optional
|
||||
|
@ -143,26 +184,10 @@ class BuilderGenerator(
|
|||
when (coreType) {
|
||||
is RustType.Vec -> renderVecHelper(memberName, coreType)
|
||||
is RustType.HashMap -> renderMapHelper(memberName, coreType)
|
||||
else -> {
|
||||
val signature = when (coreType) {
|
||||
is RustType.String,
|
||||
is RustType.Box -> "(mut self, input: impl Into<${coreType.render(true)}>) -> Self"
|
||||
else -> "(mut self, input: ${coreType.render(true)}) -> Self"
|
||||
}
|
||||
writer.documentShape(member, model)
|
||||
writer.rustBlock("pub fn $memberName$signature") {
|
||||
write("self.$memberName = Some(${builderConverter(coreType)});")
|
||||
write("self")
|
||||
}
|
||||
}
|
||||
else -> renderBuilderMemberFn(this, coreType, member, memberName, memberSymbol)
|
||||
}
|
||||
|
||||
// Render a `set_foo` method. This is useful as a target for code generation, because the argument type
|
||||
// is the same as the resulting member type, and is always optional.
|
||||
val inputType = outerType.asOptional()
|
||||
writer.rustBlock("pub fn ${member.setterName()}(mut self, input: ${inputType.render(true)}) -> Self") {
|
||||
rust("self.$memberName = input; self")
|
||||
}
|
||||
renderBuilderMemberSetterFn(this, outerType, member, memberName, memberSymbol)
|
||||
}
|
||||
buildFn(this)
|
||||
}
|
||||
|
|
|
@ -195,6 +195,7 @@ abstract class HttpProtocolGenerator(
|
|||
) {
|
||||
withBlock("Ok({", "})") {
|
||||
features.forEach { it.section(OperationSection.MutateInput("self", "_config"))(this) }
|
||||
rust("let properties = std::sync::Arc::new(std::sync::Mutex::new(smithy_http::property_bag::PropertyBag::new()));")
|
||||
rust("let request = self.request_builder_base()?;")
|
||||
withBlock("let body =", ";") {
|
||||
body("self", shape)
|
||||
|
@ -203,7 +204,7 @@ abstract class HttpProtocolGenerator(
|
|||
rust(
|
||||
"""
|
||||
##[allow(unused_mut)]
|
||||
let mut request = #T::Request::new(request.map(#T::from));
|
||||
let mut request = #T::Request::from_parts(request.map(#T::from), properties);
|
||||
""",
|
||||
operationModule, sdkBody
|
||||
)
|
||||
|
|
|
@ -146,7 +146,7 @@ class HttpProtocolTestGenerator(
|
|||
val Tokio = CargoDependency(
|
||||
"tokio",
|
||||
CratesIo("1"),
|
||||
features = listOf("macros", "test-util", "rt"),
|
||||
features = setOf("macros", "test-util", "rt"),
|
||||
scope = DependencyScope.Dev
|
||||
)
|
||||
testModuleWriter.addDependency(Tokio)
|
||||
|
|
|
@ -24,12 +24,12 @@ class UnionGenerator(
|
|||
private val writer: RustWriter,
|
||||
private val shape: UnionShape
|
||||
) {
|
||||
private val sortedMembers: List<MemberShape> = shape.allMembers.values.sortedBy { symbolProvider.toMemberName(it) }
|
||||
|
||||
fun render() {
|
||||
renderUnion()
|
||||
}
|
||||
|
||||
private val sortedMembers: List<MemberShape> = shape.allMembers.values.sortedBy { symbolProvider.toMemberName(it) }
|
||||
private fun renderUnion() {
|
||||
val unionSymbol = symbolProvider.toSymbol(shape)
|
||||
val containerMeta = unionSymbol.expectRustMetadata()
|
||||
|
|
|
@ -116,12 +116,11 @@ class ServiceConfigGenerator(private val customizations: List<ConfigCustomizatio
|
|||
writer.rustBlock("impl std::fmt::Debug for Config") {
|
||||
rustTemplate(
|
||||
"""
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let mut config = f.debug_struct("Config");
|
||||
config.finish()
|
||||
}
|
||||
|
||||
"""
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let mut config = f.debug_struct("Config");
|
||||
config.finish()
|
||||
}
|
||||
"""
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -129,7 +128,7 @@ class ServiceConfigGenerator(private val customizations: List<ConfigCustomizatio
|
|||
rustTemplate(
|
||||
"""
|
||||
pub fn builder() -> Builder { Builder::default() }
|
||||
"""
|
||||
"""
|
||||
)
|
||||
customizations.forEach {
|
||||
it.section(ServiceConfig.ConfigImpl)(this)
|
||||
|
|
|
@ -35,6 +35,8 @@ import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig
|
|||
import software.amazon.smithy.rust.codegen.smithy.makeOptional
|
||||
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingDescriptor
|
||||
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation
|
||||
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol
|
||||
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.EventStreamUnmarshallerGenerator
|
||||
import software.amazon.smithy.rust.codegen.smithy.rustType
|
||||
import software.amazon.smithy.rust.codegen.util.dq
|
||||
import software.amazon.smithy.rust.codegen.util.hasTrait
|
||||
|
@ -42,7 +44,11 @@ import software.amazon.smithy.rust.codegen.util.isPrimitive
|
|||
import software.amazon.smithy.rust.codegen.util.isStreaming
|
||||
import software.amazon.smithy.rust.codegen.util.toSnakeCase
|
||||
|
||||
class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val operationShape: OperationShape) {
|
||||
class ResponseBindingGenerator(
|
||||
private val protocol: Protocol,
|
||||
protocolConfig: ProtocolConfig,
|
||||
private val operationShape: OperationShape
|
||||
) {
|
||||
private val runtimeConfig = protocolConfig.runtimeConfig
|
||||
private val symbolProvider = protocolConfig.symbolProvider
|
||||
private val model = protocolConfig.model
|
||||
|
@ -124,6 +130,7 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera
|
|||
* Generate a function to deserialize `[binding]` from the response payload
|
||||
*/
|
||||
fun generateDeserializePayloadFn(
|
||||
operationShape: OperationShape,
|
||||
binding: HttpBindingDescriptor,
|
||||
errorT: RuntimeType,
|
||||
// Deserialize a single structure or union member marked as a payload
|
||||
|
@ -142,7 +149,13 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera
|
|||
outputT,
|
||||
errorT
|
||||
) {
|
||||
deserializeStreamingBody(binding)
|
||||
// Streaming unions are Event Streams and should be handled separately
|
||||
val target = model.expectShape(binding.member.target)
|
||||
if (target is UnionShape) {
|
||||
bindEventStreamOutput(operationShape, target)
|
||||
} else {
|
||||
deserializeStreamingBody(binding)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rustWriter.rustBlock("pub fn $fnName(body: &[u8]) -> std::result::Result<#T, #T>", outputT, errorT) {
|
||||
|
@ -157,6 +170,27 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera
|
|||
}
|
||||
}
|
||||
|
||||
private fun RustWriter.bindEventStreamOutput(operationShape: OperationShape, target: UnionShape) {
|
||||
val unmarshallerConstructorFn = EventStreamUnmarshallerGenerator(
|
||||
protocol,
|
||||
model,
|
||||
runtimeConfig,
|
||||
symbolProvider,
|
||||
operationShape,
|
||||
target
|
||||
).render()
|
||||
rustTemplate(
|
||||
"""
|
||||
let unmarshaller = #{unmarshallerConstructorFn}();
|
||||
let body = std::mem::replace(body, #{SdkBody}::taken());
|
||||
Ok(#{Receiver}::new(unmarshaller, body))
|
||||
""",
|
||||
"SdkBody" to RuntimeType.sdkBody(runtimeConfig),
|
||||
"unmarshallerConstructorFn" to unmarshallerConstructorFn,
|
||||
"Receiver" to RuntimeType.eventStreamReceiver(runtimeConfig),
|
||||
)
|
||||
}
|
||||
|
||||
private fun RustWriter.deserializeStreamingBody(binding: HttpBindingDescriptor) {
|
||||
val member = binding.member
|
||||
val targetShape = model.expectShape(member.target)
|
||||
|
@ -164,10 +198,10 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera
|
|||
rustTemplate(
|
||||
"""
|
||||
// replace the body with an empty body
|
||||
let body = std::mem::replace(body, #{sdk_body}::taken());
|
||||
Ok(#{byte_stream}::new(body))
|
||||
let body = std::mem::replace(body, #{SdkBody}::taken());
|
||||
Ok(#{ByteStream}::new(body))
|
||||
""",
|
||||
"byte_stream" to RuntimeType.byteStream(runtimeConfig), "sdk_body" to RuntimeType.sdkBody(runtimeConfig)
|
||||
"ByteStream" to RuntimeType.byteStream(runtimeConfig), "SdkBody" to RuntimeType.sdkBody(runtimeConfig)
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
@ -8,7 +8,6 @@ package software.amazon.smithy.rust.codegen.smithy.protocols
|
|||
import software.amazon.smithy.model.Model
|
||||
import software.amazon.smithy.model.pattern.UriPattern
|
||||
import software.amazon.smithy.model.shapes.OperationShape
|
||||
import software.amazon.smithy.model.shapes.StructureShape
|
||||
import software.amazon.smithy.model.shapes.ToShapeId
|
||||
import software.amazon.smithy.model.traits.HttpTrait
|
||||
import software.amazon.smithy.model.traits.TimestampFormatTrait
|
||||
|
@ -24,9 +23,6 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.parse.JsonParserGene
|
|||
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
|
||||
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSerializerGenerator
|
||||
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator
|
||||
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
|
||||
import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStreamOperations
|
||||
import software.amazon.smithy.rust.codegen.smithy.transformers.StructureModifier
|
||||
import software.amazon.smithy.rust.codegen.util.inputShape
|
||||
import software.amazon.smithy.rust.codegen.util.orNull
|
||||
|
||||
|
@ -47,17 +43,7 @@ class AwsJsonFactory(private val version: AwsJsonVersion) : ProtocolGeneratorFac
|
|||
return HttpBoundProtocolGenerator(protocolConfig, AwsJson(protocolConfig, version))
|
||||
}
|
||||
|
||||
private val shapeIfHasMembers: StructureModifier = { _, shape: StructureShape? ->
|
||||
when (shape?.members().isNullOrEmpty()) {
|
||||
true -> null
|
||||
else -> shape
|
||||
}
|
||||
}
|
||||
|
||||
override fun transformModel(model: Model): Model {
|
||||
// For AwsJson10, the body matches 1:1 with the input
|
||||
return OperationNormalizer(model).transformModel().let(RemoveEventStreamOperations::transform)
|
||||
}
|
||||
override fun transformModel(model: Model): Model = model
|
||||
|
||||
override fun support(): ProtocolSupport = ProtocolSupport(
|
||||
requestSerialization = true,
|
||||
|
|
|
@ -24,17 +24,13 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.parse.AwsQueryParser
|
|||
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
|
||||
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.AwsQuerySerializerGenerator
|
||||
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator
|
||||
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
|
||||
import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStreamOperations
|
||||
import software.amazon.smithy.rust.codegen.util.getTrait
|
||||
|
||||
class AwsQueryFactory : ProtocolGeneratorFactory<HttpBoundProtocolGenerator> {
|
||||
override fun buildProtocolGenerator(protocolConfig: ProtocolConfig): HttpBoundProtocolGenerator =
|
||||
HttpBoundProtocolGenerator(protocolConfig, AwsQueryProtocol(protocolConfig))
|
||||
|
||||
override fun transformModel(model: Model): Model {
|
||||
return OperationNormalizer(model).transformModel().let(RemoveEventStreamOperations::transform)
|
||||
}
|
||||
override fun transformModel(model: Model): Model = model
|
||||
|
||||
override fun support(): ProtocolSupport {
|
||||
return ProtocolSupport(
|
||||
|
|
|
@ -22,16 +22,12 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.parse.Ec2QueryParser
|
|||
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
|
||||
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.Ec2QuerySerializerGenerator
|
||||
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator
|
||||
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
|
||||
import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStreamOperations
|
||||
|
||||
class Ec2QueryFactory : ProtocolGeneratorFactory<HttpBoundProtocolGenerator> {
|
||||
override fun buildProtocolGenerator(protocolConfig: ProtocolConfig): HttpBoundProtocolGenerator =
|
||||
HttpBoundProtocolGenerator(protocolConfig, Ec2QueryProtocol(protocolConfig))
|
||||
|
||||
override fun transformModel(model: Model): Model {
|
||||
return OperationNormalizer(model).transformModel().let(RemoveEventStreamOperations::transform)
|
||||
}
|
||||
override fun transformModel(model: Model): Model = model
|
||||
|
||||
override fun support(): ProtocolSupport {
|
||||
return ProtocolSupport(
|
||||
|
|
|
@ -40,6 +40,7 @@ import software.amazon.smithy.rust.codegen.smithy.generators.operationBuildError
|
|||
import software.amazon.smithy.rust.codegen.smithy.generators.setterName
|
||||
import software.amazon.smithy.rust.codegen.smithy.isOptional
|
||||
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
|
||||
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.EventStreamMarshallerGenerator
|
||||
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator
|
||||
import software.amazon.smithy.rust.codegen.smithy.transformers.errorMessageMember
|
||||
import software.amazon.smithy.rust.codegen.util.dq
|
||||
|
@ -47,6 +48,8 @@ import software.amazon.smithy.rust.codegen.util.expectMember
|
|||
import software.amazon.smithy.rust.codegen.util.hasStreamingMember
|
||||
import software.amazon.smithy.rust.codegen.util.hasTrait
|
||||
import software.amazon.smithy.rust.codegen.util.inputShape
|
||||
import software.amazon.smithy.rust.codegen.util.isEventStream
|
||||
import software.amazon.smithy.rust.codegen.util.isInputEventStream
|
||||
import software.amazon.smithy.rust.codegen.util.isStreaming
|
||||
import software.amazon.smithy.rust.codegen.util.outputShape
|
||||
import software.amazon.smithy.rust.codegen.util.toSnakeCase
|
||||
|
@ -81,10 +84,12 @@ class HttpBoundProtocolGenerator(
|
|||
"ParseStrict" to RuntimeType.parseStrict(runtimeConfig),
|
||||
"ParseResponse" to RuntimeType.parseResponse(runtimeConfig),
|
||||
"http" to RuntimeType.http,
|
||||
"hyper" to CargoDependency.HyperWithStream.asType(),
|
||||
"operation" to RuntimeType.operationModule(runtimeConfig),
|
||||
"Bytes" to RuntimeType.Bytes,
|
||||
"SdkBody" to RuntimeType.sdkBody(runtimeConfig),
|
||||
"BuildError" to runtimeConfig.operationBuildError()
|
||||
"BuildError" to runtimeConfig.operationBuildError(),
|
||||
"SmithyHttp" to CargoDependency.SmithyHttp(runtimeConfig).asType()
|
||||
)
|
||||
|
||||
override fun RustWriter.body(self: String, operationShape: OperationShape): BodyMetadata {
|
||||
|
@ -103,10 +108,50 @@ class HttpBoundProtocolGenerator(
|
|||
BodyMetadata(takesOwnership = false)
|
||||
} else {
|
||||
val member = inputShape.expectMember(payloadMemberName)
|
||||
serializeViaPayload(member, serializerGenerator)
|
||||
if (operationShape.isInputEventStream(model)) {
|
||||
serializeViaEventStream(operationShape, member, serializerGenerator)
|
||||
} else {
|
||||
serializeViaPayload(member, serializerGenerator)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun RustWriter.serializeViaEventStream(
|
||||
operationShape: OperationShape,
|
||||
memberShape: MemberShape,
|
||||
serializerGenerator: StructuredDataSerializerGenerator
|
||||
): BodyMetadata {
|
||||
val memberName = symbolProvider.toMemberName(memberShape)
|
||||
val unionShape = model.expectShape(memberShape.target, UnionShape::class.java)
|
||||
|
||||
val marshallerConstructorFn = EventStreamMarshallerGenerator(
|
||||
model,
|
||||
runtimeConfig,
|
||||
symbolProvider,
|
||||
unionShape,
|
||||
serializerGenerator
|
||||
).render()
|
||||
|
||||
// TODO(EventStream): [RPC] RPC protocols need to send an initial message with the
|
||||
// parameters that are not `@eventHeader` or `@eventPayload`.
|
||||
rustTemplate(
|
||||
"""
|
||||
{
|
||||
let marshaller = #{marshallerConstructorFn}();
|
||||
let signer = _config.new_event_stream_signer(properties.clone());
|
||||
let adapter: #{SmithyHttp}::event_stream::MessageStreamAdapter<_, #{OperationError}> =
|
||||
self.$memberName.into_body_stream(marshaller, signer);
|
||||
let body: #{SdkBody} = #{hyper}::Body::wrap_stream(adapter).into();
|
||||
body
|
||||
}
|
||||
""",
|
||||
*codegenScope,
|
||||
"marshallerConstructorFn" to marshallerConstructorFn,
|
||||
"OperationError" to operationShape.errorSymbol(symbolProvider)
|
||||
)
|
||||
return BodyMetadata(takesOwnership = true)
|
||||
}
|
||||
|
||||
private fun RustWriter.serializeViaPayload(
|
||||
member: MemberShape,
|
||||
serializerGenerator: StructuredDataSerializerGenerator
|
||||
|
@ -175,6 +220,10 @@ class HttpBoundProtocolGenerator(
|
|||
BodyMetadata(takesOwnership = true)
|
||||
}
|
||||
is StructureShape, is UnionShape -> {
|
||||
check(
|
||||
!((targetShape as? UnionShape)?.isEventStream() ?: false)
|
||||
) { "Event Streams should be handled further up" }
|
||||
|
||||
// JSON serialize the structure or union targeted
|
||||
rust(
|
||||
"""#T(&$payloadName).map_err(|err|#T::SerializationError(err.into()))?""",
|
||||
|
@ -430,7 +479,7 @@ class HttpBoundProtocolGenerator(
|
|||
bindings: List<HttpBindingDescriptor>,
|
||||
errorSymbol: RuntimeType,
|
||||
) {
|
||||
val httpBindingGenerator = ResponseBindingGenerator(protocolConfig, operationShape)
|
||||
val httpBindingGenerator = ResponseBindingGenerator(protocol, protocolConfig, operationShape)
|
||||
val structuredDataParser = protocol.structuredDataParser(operationShape)
|
||||
Attribute.AllowUnusedMut.render(this)
|
||||
rust("let mut output = #T::default();", outputShape.builderSymbol(symbolProvider))
|
||||
|
@ -464,7 +513,7 @@ class HttpBoundProtocolGenerator(
|
|||
}
|
||||
|
||||
val err = if (StructureGenerator.fallibleBuilder(outputShape, symbolProvider)) {
|
||||
".map_err(|s|${format(errorSymbol)}::unhandled(s))?"
|
||||
".map_err(${format(errorSymbol)}::unhandled)?"
|
||||
} else ""
|
||||
rust("output.build()$err")
|
||||
}
|
||||
|
@ -509,6 +558,7 @@ class HttpBoundProtocolGenerator(
|
|||
rust("#T($body).map_err(#T::unhandled)", structuredDataParser.payloadParser(member), errorSymbol)
|
||||
}
|
||||
val deserializer = httpBindingGenerator.generateDeserializePayloadFn(
|
||||
operationShape,
|
||||
binding,
|
||||
errorSymbol,
|
||||
docHandler = docShapeHandler,
|
||||
|
|
|
@ -20,17 +20,13 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.parse.JsonParserGene
|
|||
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
|
||||
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSerializerGenerator
|
||||
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator
|
||||
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
|
||||
import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStreamOperations
|
||||
|
||||
class RestJsonFactory : ProtocolGeneratorFactory<HttpBoundProtocolGenerator> {
|
||||
override fun buildProtocolGenerator(
|
||||
protocolConfig: ProtocolConfig
|
||||
): HttpBoundProtocolGenerator = HttpBoundProtocolGenerator(protocolConfig, RestJson(protocolConfig))
|
||||
|
||||
override fun transformModel(model: Model): Model {
|
||||
return OperationNormalizer(model).transformModel().let(RemoveEventStreamOperations::transform)
|
||||
}
|
||||
override fun transformModel(model: Model): Model = model
|
||||
|
||||
override fun support(): ProtocolSupport {
|
||||
return ProtocolSupport(
|
||||
|
|
|
@ -21,8 +21,6 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.parse.RestXmlParserG
|
|||
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
|
||||
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator
|
||||
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.XmlBindingTraitSerializerGenerator
|
||||
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
|
||||
import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStreamOperations
|
||||
import software.amazon.smithy.rust.codegen.util.expectTrait
|
||||
|
||||
class RestXmlFactory(private val generator: (ProtocolConfig) -> Protocol = { RestXml(it) }) :
|
||||
|
@ -33,9 +31,7 @@ class RestXmlFactory(private val generator: (ProtocolConfig) -> Protocol = { Res
|
|||
return HttpBoundProtocolGenerator(protocolConfig, generator(protocolConfig))
|
||||
}
|
||||
|
||||
override fun transformModel(model: Model): Model {
|
||||
return OperationNormalizer(model).transformModel().let(RemoveEventStreamOperations::transform)
|
||||
}
|
||||
override fun transformModel(model: Model): Model = model
|
||||
|
||||
override fun support(): ProtocolSupport {
|
||||
return ProtocolSupport(
|
||||
|
|
|
@ -0,0 +1,288 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0.
|
||||
*/
|
||||
|
||||
package software.amazon.smithy.rust.codegen.smithy.protocols.parse
|
||||
|
||||
import software.amazon.smithy.codegen.core.Symbol
|
||||
import software.amazon.smithy.model.Model
|
||||
import software.amazon.smithy.model.shapes.BlobShape
|
||||
import software.amazon.smithy.model.shapes.BooleanShape
|
||||
import software.amazon.smithy.model.shapes.ByteShape
|
||||
import software.amazon.smithy.model.shapes.IntegerShape
|
||||
import software.amazon.smithy.model.shapes.LongShape
|
||||
import software.amazon.smithy.model.shapes.MemberShape
|
||||
import software.amazon.smithy.model.shapes.OperationShape
|
||||
import software.amazon.smithy.model.shapes.ShortShape
|
||||
import software.amazon.smithy.model.shapes.StringShape
|
||||
import software.amazon.smithy.model.shapes.StructureShape
|
||||
import software.amazon.smithy.model.shapes.TimestampShape
|
||||
import software.amazon.smithy.model.shapes.UnionShape
|
||||
import software.amazon.smithy.model.traits.EventHeaderTrait
|
||||
import software.amazon.smithy.model.traits.EventPayloadTrait
|
||||
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
|
||||
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
|
||||
import software.amazon.smithy.rust.codegen.rustlang.rust
|
||||
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
|
||||
import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate
|
||||
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
|
||||
import software.amazon.smithy.rust.codegen.rustlang.withBlock
|
||||
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
|
||||
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
|
||||
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
|
||||
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
|
||||
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol
|
||||
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticEventStreamUnionTrait
|
||||
import software.amazon.smithy.rust.codegen.util.dq
|
||||
import software.amazon.smithy.rust.codegen.util.expectTrait
|
||||
import software.amazon.smithy.rust.codegen.util.hasTrait
|
||||
import software.amazon.smithy.rust.codegen.util.toPascalCase
|
||||
|
||||
class EventStreamUnmarshallerGenerator(
|
||||
private val protocol: Protocol,
|
||||
private val model: Model,
|
||||
runtimeConfig: RuntimeConfig,
|
||||
private val symbolProvider: RustSymbolProvider,
|
||||
private val operationShape: OperationShape,
|
||||
private val unionShape: UnionShape,
|
||||
) {
|
||||
private val unionSymbol = symbolProvider.toSymbol(unionShape)
|
||||
private val operationErrorSymbol = operationShape.errorSymbol(symbolProvider)
|
||||
private val smithyEventStream = CargoDependency.SmithyEventStream(runtimeConfig)
|
||||
private val codegenScope = arrayOf(
|
||||
"Blob" to RuntimeType("Blob", CargoDependency.SmithyTypes(runtimeConfig), "smithy_types"),
|
||||
"Error" to RuntimeType("Error", smithyEventStream, "smithy_eventstream::error"),
|
||||
"Header" to RuntimeType("Header", smithyEventStream, "smithy_eventstream::frame"),
|
||||
"HeaderValue" to RuntimeType("HeaderValue", smithyEventStream, "smithy_eventstream::frame"),
|
||||
"expect_fns" to RuntimeType("smithy", smithyEventStream, "smithy_eventstream"),
|
||||
"Message" to RuntimeType("Message", smithyEventStream, "smithy_eventstream::frame"),
|
||||
"SmithyError" to RuntimeType("Error", CargoDependency.SmithyTypes(runtimeConfig), "smithy_types"),
|
||||
"UnmarshallMessage" to RuntimeType("UnmarshallMessage", smithyEventStream, "smithy_eventstream::frame"),
|
||||
"UnmarshalledMessage" to RuntimeType("UnmarshalledMessage", smithyEventStream, "smithy_eventstream::frame"),
|
||||
)
|
||||
|
||||
fun render(): RuntimeType {
|
||||
val unmarshallerType = unionShape.eventStreamUnmarshallerType()
|
||||
return RuntimeType.forInlineFun("${unmarshallerType.name}::new", "event_stream_serde") { inlineWriter ->
|
||||
inlineWriter.renderUnmarshaller(unmarshallerType, unionSymbol)
|
||||
}
|
||||
}
|
||||
|
||||
private fun RustWriter.renderUnmarshaller(unmarshallerType: RuntimeType, unionSymbol: Symbol) {
|
||||
rust(
|
||||
"""
|
||||
##[non_exhaustive]
|
||||
##[derive(Debug)]
|
||||
pub struct ${unmarshallerType.name};
|
||||
|
||||
impl ${unmarshallerType.name} {
|
||||
pub fn new() -> Self {
|
||||
${unmarshallerType.name}
|
||||
}
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
rustBlockTemplate(
|
||||
"impl #{UnmarshallMessage} for ${unmarshallerType.name}",
|
||||
*codegenScope
|
||||
) {
|
||||
rust("type Output = #T;", unionSymbol)
|
||||
rust("type Error = #T;", operationErrorSymbol)
|
||||
|
||||
rustBlockTemplate(
|
||||
"""
|
||||
fn unmarshall(
|
||||
&self,
|
||||
message: &#{Message}
|
||||
) -> std::result::Result<#{UnmarshalledMessage}<Self::Output, Self::Error>, #{Error}>
|
||||
""",
|
||||
*codegenScope
|
||||
) {
|
||||
rustBlockTemplate(
|
||||
"""
|
||||
let response_headers = #{expect_fns}::parse_response_headers(&message)?;
|
||||
match response_headers.message_type.as_str()
|
||||
""",
|
||||
*codegenScope
|
||||
) {
|
||||
rustBlock("\"event\" => ") {
|
||||
renderUnmarshallEvent()
|
||||
}
|
||||
rustBlock("\"exception\" => ") {
|
||||
renderUnmarshallError()
|
||||
}
|
||||
rustBlock("value => ") {
|
||||
rustTemplate(
|
||||
"return Err(#{Error}::Unmarshalling(format!(\"unrecognized :message-type: {}\", value)));",
|
||||
*codegenScope
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun RustWriter.renderUnmarshallEvent() {
|
||||
rustBlock("match response_headers.smithy_type.as_str()") {
|
||||
for (member in unionShape.members()) {
|
||||
val target = model.expectShape(member.target, StructureShape::class.java)
|
||||
rustBlock("${member.memberName.dq()} => ") {
|
||||
renderUnmarshallUnionMember(member, target)
|
||||
}
|
||||
}
|
||||
rustBlock("smithy_type => ") {
|
||||
// TODO: Handle this better once unions support unknown variants
|
||||
rustTemplate(
|
||||
"return Err(#{Error}::Unmarshalling(format!(\"unrecognized :event-type: {}\", smithy_type)));",
|
||||
*codegenScope
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun RustWriter.renderUnmarshallUnionMember(unionMember: MemberShape, unionStruct: StructureShape) {
|
||||
val unionMemberName = unionMember.memberName.toPascalCase()
|
||||
val payloadOnly =
|
||||
unionStruct.members().none { it.hasTrait<EventPayloadTrait>() || it.hasTrait<EventHeaderTrait>() }
|
||||
if (payloadOnly) {
|
||||
withBlock("let parsed = ", ";") {
|
||||
renderParseProtocolPayload(unionMember)
|
||||
}
|
||||
rustTemplate(
|
||||
"Ok(#{UnmarshalledMessage}::Event(#{Output}::$unionMemberName(parsed)))",
|
||||
"Output" to unionSymbol,
|
||||
*codegenScope
|
||||
)
|
||||
} else {
|
||||
rust("let mut builder = #T::builder();", symbolProvider.toSymbol(unionStruct))
|
||||
val payloadMember = unionStruct.members().firstOrNull { it.hasTrait<EventPayloadTrait>() }
|
||||
if (payloadMember != null) {
|
||||
renderUnmarshallEventPayload(payloadMember)
|
||||
}
|
||||
val headerMembers = unionStruct.members().filter { it.hasTrait<EventHeaderTrait>() }
|
||||
if (headerMembers.isNotEmpty()) {
|
||||
rustBlock("for header in message.headers()") {
|
||||
rustBlock("match header.name().as_str()") {
|
||||
for (member in headerMembers) {
|
||||
rustBlock("${member.memberName.dq()} => ") {
|
||||
renderUnmarshallEventHeader(member)
|
||||
}
|
||||
}
|
||||
rust("_ => {}")
|
||||
}
|
||||
}
|
||||
}
|
||||
rustTemplate(
|
||||
"Ok(#{UnmarshalledMessage}::Event(#{Output}::$unionMemberName(builder.build())))",
|
||||
"Output" to unionSymbol,
|
||||
*codegenScope
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
private fun RustWriter.renderUnmarshallEventHeader(member: MemberShape) {
|
||||
val memberName = symbolProvider.toMemberName(member)
|
||||
withBlock("builder = builder.$memberName(", ");") {
|
||||
when (val target = model.expectShape(member.target)) {
|
||||
is BooleanShape -> rustTemplate("#{expect_fns}::expect_bool(header)?", *codegenScope)
|
||||
is ByteShape -> rustTemplate("#{expect_fns}::expect_byte(header)?", *codegenScope)
|
||||
is ShortShape -> rustTemplate("#{expect_fns}::expect_int16(header)?", *codegenScope)
|
||||
is IntegerShape -> rustTemplate("#{expect_fns}::expect_int32(header)?", *codegenScope)
|
||||
is LongShape -> rustTemplate("#{expect_fns}::expect_int64(header)?", *codegenScope)
|
||||
is BlobShape -> rustTemplate("#{expect_fns}::expect_byte_array(header)?", *codegenScope)
|
||||
is StringShape -> rustTemplate("#{expect_fns}::expect_string(header)?", *codegenScope)
|
||||
is TimestampShape -> rustTemplate("#{expect_fns}::expect_timestamp(header)?", *codegenScope)
|
||||
else -> throw IllegalStateException("unsupported event stream header shape type: $target")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun RustWriter.renderUnmarshallEventPayload(member: MemberShape) {
|
||||
// TODO(EventStream): [RPC] Don't blow up on an initial-message that's not part of the union (:event-type will be "initial-request" or "initial-response")
|
||||
// TODO(EventStream): [RPC] Incorporate initial-message into original output (:event-type will be "initial-request" or "initial-response")
|
||||
val memberName = symbolProvider.toMemberName(member)
|
||||
withBlock("builder = builder.$memberName(", ");") {
|
||||
when (model.expectShape(member.target)) {
|
||||
is BlobShape -> {
|
||||
rustTemplate("#{Blob}::new(message.payload().as_ref())", *codegenScope)
|
||||
}
|
||||
is StringShape -> {
|
||||
rustTemplate(
|
||||
"""
|
||||
std::str::from_utf8(message.payload())
|
||||
.map_err(|_| #{Error}::Unmarshalling("message payload is not valid UTF-8".into()))?
|
||||
""",
|
||||
*codegenScope
|
||||
)
|
||||
}
|
||||
is UnionShape, is StructureShape -> {
|
||||
renderParseProtocolPayload(member)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun RustWriter.renderParseProtocolPayload(member: MemberShape) {
|
||||
// TODO(EventStream): Check :content-type against expected content-type, error if unexpected
|
||||
val parser = protocol.structuredDataParser(operationShape).payloadParser(member)
|
||||
val memberName = member.memberName.toPascalCase()
|
||||
rustTemplate(
|
||||
"""
|
||||
#{parser}(&message.payload()[..])
|
||||
.map_err(|err| {
|
||||
#{Error}::Unmarshalling(format!("failed to unmarshall $memberName: {}", err))
|
||||
})?
|
||||
""",
|
||||
"parser" to parser,
|
||||
*codegenScope
|
||||
)
|
||||
}
|
||||
|
||||
private fun RustWriter.renderUnmarshallError() {
|
||||
val syntheticUnion = unionShape.expectTrait<SyntheticEventStreamUnionTrait>()
|
||||
if (syntheticUnion.errorMembers.isNotEmpty()) {
|
||||
rustBlock("match response_headers.smithy_type.as_str()") {
|
||||
for (member in syntheticUnion.errorMembers) {
|
||||
val target = model.expectShape(member.target, StructureShape::class.java)
|
||||
rustBlock("${member.memberName.dq()} => ") {
|
||||
val parser = protocol.structuredDataParser(operationShape).errorParser(target)
|
||||
if (parser != null) {
|
||||
rust("let mut builder = #T::builder();", symbolProvider.toSymbol(target))
|
||||
// TODO(EventStream): Errors on the operation can be disjoint with errors in the union,
|
||||
// so we need to generate a new top-level Error type for each event stream union.
|
||||
rustTemplate(
|
||||
"""
|
||||
builder = #{parser}(&message.payload()[..], builder)
|
||||
.map_err(|err| {
|
||||
#{Error}::Unmarshalling(format!("failed to unmarshall ${member.memberName}: {}", err))
|
||||
})?;
|
||||
return Ok(#{UnmarshalledMessage}::Error(
|
||||
#{OpError}::new(
|
||||
#{OpError}Kind::${member.memberName.toPascalCase()}(builder.build()),
|
||||
#{SmithyError}::builder().build(),
|
||||
)
|
||||
))
|
||||
""",
|
||||
"OpError" to operationErrorSymbol,
|
||||
"parser" to parser,
|
||||
*codegenScope
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
rust("_ => {}")
|
||||
}
|
||||
}
|
||||
// TODO(EventStream): Generic error parsing; will need to refactor `parseGenericError` to
|
||||
// operate on bodies rather than responses. This should be easy for all but restJson,
|
||||
// which pulls the error type out of a header.
|
||||
rust("unimplemented!(\"event stream generic error parsing\")")
|
||||
}
|
||||
|
||||
private fun UnionShape.eventStreamUnmarshallerType(): RuntimeType {
|
||||
val symbol = symbolProvider.toSymbol(this)
|
||||
return RuntimeType("${symbol.name.toPascalCase()}Unmarshaller", null, "crate::event_stream_serde")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,161 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0.
|
||||
*/
|
||||
|
||||
package software.amazon.smithy.rust.codegen.smithy.protocols.serialize
|
||||
|
||||
import software.amazon.smithy.codegen.core.Symbol
|
||||
import software.amazon.smithy.model.Model
|
||||
import software.amazon.smithy.model.shapes.BlobShape
|
||||
import software.amazon.smithy.model.shapes.MemberShape
|
||||
import software.amazon.smithy.model.shapes.Shape
|
||||
import software.amazon.smithy.model.shapes.StringShape
|
||||
import software.amazon.smithy.model.shapes.StructureShape
|
||||
import software.amazon.smithy.model.shapes.UnionShape
|
||||
import software.amazon.smithy.model.traits.EventHeaderTrait
|
||||
import software.amazon.smithy.model.traits.EventPayloadTrait
|
||||
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
|
||||
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
|
||||
import software.amazon.smithy.rust.codegen.rustlang.render
|
||||
import software.amazon.smithy.rust.codegen.rustlang.rust
|
||||
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
|
||||
import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate
|
||||
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
|
||||
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
|
||||
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
|
||||
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
|
||||
import software.amazon.smithy.rust.codegen.smithy.rustType
|
||||
import software.amazon.smithy.rust.codegen.util.dq
|
||||
import software.amazon.smithy.rust.codegen.util.hasTrait
|
||||
import software.amazon.smithy.rust.codegen.util.toPascalCase
|
||||
|
||||
// TODO(EventStream): [TEST] Unit test EventStreamMarshallerGenerator
|
||||
class EventStreamMarshallerGenerator(
|
||||
private val model: Model,
|
||||
runtimeConfig: RuntimeConfig,
|
||||
private val symbolProvider: RustSymbolProvider,
|
||||
private val unionShape: UnionShape,
|
||||
private val serializerGenerator: StructuredDataSerializerGenerator,
|
||||
) {
|
||||
private val smithyEventStream = CargoDependency.SmithyEventStream(runtimeConfig)
|
||||
private val codegenScope = arrayOf(
|
||||
"MarshallMessage" to RuntimeType("MarshallMessage", smithyEventStream, "smithy_eventstream::frame"),
|
||||
"Message" to RuntimeType("Message", smithyEventStream, "smithy_eventstream::frame"),
|
||||
"Header" to RuntimeType("Header", smithyEventStream, "smithy_eventstream::frame"),
|
||||
"HeaderValue" to RuntimeType("HeaderValue", smithyEventStream, "smithy_eventstream::frame"),
|
||||
"Error" to RuntimeType("Error", smithyEventStream, "smithy_eventstream::error"),
|
||||
)
|
||||
|
||||
fun render(): RuntimeType {
|
||||
val marshallerType = unionShape.eventStreamMarshallerType()
|
||||
val unionSymbol = symbolProvider.toSymbol(unionShape)
|
||||
|
||||
return RuntimeType.forInlineFun("${marshallerType.name}::new", "event_stream_serde") { inlineWriter ->
|
||||
inlineWriter.renderMarshaller(marshallerType, unionSymbol)
|
||||
}
|
||||
}
|
||||
|
||||
private fun RustWriter.renderMarshaller(marshallerType: RuntimeType, unionSymbol: Symbol) {
|
||||
rust(
|
||||
"""
|
||||
##[non_exhaustive]
|
||||
##[derive(Debug)]
|
||||
pub struct ${marshallerType.name};
|
||||
|
||||
impl ${marshallerType.name} {
|
||||
pub fn new() -> Self {
|
||||
${marshallerType.name}
|
||||
}
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
rustBlockTemplate(
|
||||
"impl #{MarshallMessage} for ${marshallerType.name}",
|
||||
*codegenScope
|
||||
) {
|
||||
rust("type Input = ${unionSymbol.rustType().render(fullyQualified = true)};")
|
||||
|
||||
rustBlockTemplate(
|
||||
"fn marshall(&self, input: Self::Input) -> std::result::Result<#{Message}, #{Error}>",
|
||||
*codegenScope
|
||||
) {
|
||||
rust("let mut headers = Vec::new();")
|
||||
addStringHeader(":message-type", "\"event\".into()")
|
||||
rustBlock("let payload = match input") {
|
||||
for (member in unionShape.members()) {
|
||||
val eventType = member.memberName // must be the original name, not the Rust-safe name
|
||||
rustBlock("Self::Input::${member.memberName.toPascalCase()}(inner) => ") {
|
||||
addStringHeader(":event-type", "${eventType.dq()}.into()")
|
||||
val target = model.expectShape(member.target, StructureShape::class.java)
|
||||
serializeEvent(target)
|
||||
}
|
||||
}
|
||||
}
|
||||
rustTemplate("; Ok(#{Message}::new_from_parts(headers, payload))", *codegenScope)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun RustWriter.serializeEvent(struct: StructureShape) {
|
||||
for (member in struct.members()) {
|
||||
val memberName = symbolProvider.toMemberName(member)
|
||||
val target = model.expectShape(member.target)
|
||||
if (member.hasTrait<EventPayloadTrait>()) {
|
||||
serializeUnionMember(memberName, member, target)
|
||||
} else if (member.hasTrait<EventHeaderTrait>()) {
|
||||
TODO("TODO(EventStream): Implement @eventHeader trait")
|
||||
} else {
|
||||
throw IllegalStateException("Event Stream members must be a header or payload")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun RustWriter.serializeUnionMember(memberName: String, member: MemberShape, target: Shape) {
|
||||
if (target is BlobShape || target is StringShape) {
|
||||
data class PayloadContext(val conversionFn: String, val contentType: String)
|
||||
val ctx = when (target) {
|
||||
is BlobShape -> PayloadContext("into_inner", "application/octet-stream")
|
||||
is StringShape -> PayloadContext("into_bytes", "text/plain")
|
||||
else -> throw IllegalStateException("unreachable")
|
||||
}
|
||||
addStringHeader(":content-type", "${ctx.contentType.dq()}.into()")
|
||||
if (member.isOptional) {
|
||||
rust(
|
||||
"""
|
||||
if let Some(inner_payload) = inner.$memberName {
|
||||
inner_payload.${ctx.conversionFn}()
|
||||
} else {
|
||||
Vec::new()
|
||||
}
|
||||
"""
|
||||
)
|
||||
} else {
|
||||
rust("inner.$memberName.${ctx.conversionFn}()")
|
||||
}
|
||||
} else {
|
||||
// TODO(EventStream): Select content-type based on protocol
|
||||
addStringHeader(":content-type", "\"TODO\".into()")
|
||||
|
||||
val serializerFn = serializerGenerator.payloadSerializer(member)
|
||||
rustTemplate(
|
||||
"""
|
||||
#{serializerFn}(&inner.$memberName)
|
||||
.map_err(|err| #{Error}::Marshalling(format!("{}", err)))?
|
||||
""",
|
||||
"serializerFn" to serializerFn,
|
||||
*codegenScope
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
private fun RustWriter.addStringHeader(name: String, valueExpr: String) {
|
||||
rustTemplate("headers.push(#{Header}::new(${name.dq()}, #{HeaderValue}::String($valueExpr)));", *codegenScope)
|
||||
}
|
||||
|
||||
private fun UnionShape.eventStreamMarshallerType(): RuntimeType {
|
||||
val symbol = symbolProvider.toSymbol(this)
|
||||
return RuntimeType("${symbol.name.toPascalCase()}Marshaller", null, "crate::event_stream_serde")
|
||||
}
|
||||
}
|
|
@ -141,7 +141,7 @@ class JsonSerializerGenerator(
|
|||
val target = model.expectShape(member.target, StructureShape::class.java)
|
||||
return RuntimeType.forInlineFun(fnName, "operation_ser") { writer ->
|
||||
writer.rustBlockTemplate(
|
||||
"pub fn $fnName(input: &#{target}) -> Result<#{SdkBody}, #{Error}>",
|
||||
"pub fn $fnName(input: &#{target}) -> std::result::Result<std::vec::Vec<u8>, #{Error}>",
|
||||
*codegenScope,
|
||||
"target" to symbolProvider.toSymbol(target)
|
||||
) {
|
||||
|
@ -149,7 +149,7 @@ class JsonSerializerGenerator(
|
|||
rustTemplate("let mut object = #{JsonObjectWriter}::new(&mut out);", *codegenScope)
|
||||
serializeStructure(StructContext("object", "input", target))
|
||||
rust("object.finish();")
|
||||
rustTemplate("Ok(#{SdkBody}::from(out))", *codegenScope)
|
||||
rustTemplate("Ok(out.into_bytes())", *codegenScope)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -101,6 +101,8 @@ abstract class QuerySerializerGenerator(protocolConfig: ProtocolConfig) : Struct
|
|||
}
|
||||
|
||||
override fun payloadSerializer(member: MemberShape): RuntimeType {
|
||||
// TODO(EventStream): [RPC] The query will need to be rendered to the initial message,
|
||||
// so this needs to be implemented
|
||||
TODO("The $protocolName protocol doesn't support http payload traits")
|
||||
}
|
||||
|
||||
|
|
|
@ -11,30 +11,30 @@ import software.amazon.smithy.rust.codegen.smithy.RuntimeType
|
|||
|
||||
interface StructuredDataSerializerGenerator {
|
||||
/**
|
||||
* Generate a parse function for a given targeted as a payload.
|
||||
* Entry point for payload-based parsing.
|
||||
* Roughly:
|
||||
* Generate a serializer for a request payload. Expected signature:
|
||||
* ```rust
|
||||
* fn serialize_some_payload(input: &PayloadSmithyType) -> Result<Vec<u8>, Error> {
|
||||
* ...
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
fun payloadSerializer(member: MemberShape): RuntimeType
|
||||
|
||||
/** Generate a serializer for operation input
|
||||
* Because only a subset of fields of the operation may be impacted by the document, a builder is passed
|
||||
* through:
|
||||
*
|
||||
/**
|
||||
* Generate a serializer for an operation input.
|
||||
* ```rust
|
||||
* fn parse_some_operation(inp: &[u8], builder: my_operation::Builder) -> Result<my_operation::Builder, XmlError> {
|
||||
* ...
|
||||
* fn serialize_some_operation(input: &SomeSmithyType) -> Result<SdkBody, Error> {
|
||||
* ...
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
fun operationSerializer(operationShape: OperationShape): RuntimeType?
|
||||
|
||||
/**
|
||||
* Generate a serializer for a document.
|
||||
* ```rust
|
||||
* fn parse_document(inp: &[u8]) -> Result<Document, Error> {
|
||||
* ...
|
||||
* fn serialize_document(input: &Document) -> Result<SdkBody, Error> {
|
||||
* ...
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
|
|
|
@ -121,7 +121,7 @@ class XmlBindingTraitSerializerGenerator(
|
|||
let mut writer = #{XmlWriter}::new(&mut out);
|
||||
##[allow(unused_mut)]
|
||||
let mut root = writer.start_el(${operationXmlName.dq()})${inputShape.xmlNamespace().apply()};
|
||||
""",
|
||||
""",
|
||||
*codegenScope
|
||||
)
|
||||
serializeStructure(inputShape, xmlMembers, Ctx.Element("root", "&input"))
|
||||
|
@ -140,10 +140,9 @@ class XmlBindingTraitSerializerGenerator(
|
|||
val target = model.expectShape(member.target, StructureShape::class.java)
|
||||
return RuntimeType.forInlineFun(fnName, "xml_ser") {
|
||||
val t = symbolProvider.toSymbol(member).rustType().stripOuter<RustType.Option>().render(true)
|
||||
it.rustBlock(
|
||||
"pub fn $fnName(input: &$t) -> Result<#T, String>",
|
||||
|
||||
RuntimeType.sdkBody(runtimeConfig),
|
||||
it.rustBlockTemplate(
|
||||
"pub fn $fnName(input: &$t) -> std::result::Result<std::vec::Vec<u8>, String>",
|
||||
*codegenScope
|
||||
) {
|
||||
rust("let mut out = String::new();")
|
||||
// create a scope for writer. This ensure that writer has been dropped before returning the
|
||||
|
@ -156,7 +155,7 @@ class XmlBindingTraitSerializerGenerator(
|
|||
let mut root = writer.start_el(${xmlIndex.payloadShapeName(member).dq()})${
|
||||
target.xmlNamespace().apply()
|
||||
};
|
||||
""",
|
||||
""",
|
||||
*codegenScope
|
||||
)
|
||||
serializeStructure(
|
||||
|
@ -165,7 +164,7 @@ class XmlBindingTraitSerializerGenerator(
|
|||
Ctx.Element("root", "&input")
|
||||
)
|
||||
}
|
||||
rustTemplate("Ok(#{SdkBody}::from(out))", *codegenScope)
|
||||
rustTemplate("Ok(out.into_bytes())", *codegenScope)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0.
|
||||
*/
|
||||
|
||||
package software.amazon.smithy.rust.codegen.smithy.traits
|
||||
|
||||
import software.amazon.smithy.model.node.Node
|
||||
import software.amazon.smithy.model.shapes.MemberShape
|
||||
import software.amazon.smithy.model.shapes.ShapeId
|
||||
import software.amazon.smithy.model.traits.AnnotationTrait
|
||||
|
||||
class SyntheticEventStreamUnionTrait(
|
||||
val errorMembers: List<MemberShape>,
|
||||
) : AnnotationTrait(ID, Node.objectNode()) {
|
||||
companion object {
|
||||
val ID = ShapeId.from("smithy.api.internal#syntheticEventStreamUnion")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0.
|
||||
*/
|
||||
|
||||
package software.amazon.smithy.rust.codegen.smithy.transformers
|
||||
|
||||
import software.amazon.smithy.model.Model
|
||||
import software.amazon.smithy.model.shapes.UnionShape
|
||||
import software.amazon.smithy.model.traits.ErrorTrait
|
||||
import software.amazon.smithy.model.transform.ModelTransformer
|
||||
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticEventStreamUnionTrait
|
||||
import software.amazon.smithy.rust.codegen.util.hasTrait
|
||||
import software.amazon.smithy.rust.codegen.util.isEventStream
|
||||
|
||||
/**
|
||||
* Generates synthetic unions to replace the modeled unions for Event Stream types.
|
||||
* This allows us to strip out all the error union members once up-front, instead of in each
|
||||
* place that does codegen with the unions.
|
||||
*/
|
||||
object EventStreamNormalizer {
|
||||
fun transform(model: Model): Model = ModelTransformer.create().mapShapes(model) { shape ->
|
||||
if (shape is UnionShape && shape.isEventStream()) {
|
||||
syntheticEquivalent(model, shape)
|
||||
} else {
|
||||
shape
|
||||
}
|
||||
}
|
||||
|
||||
private fun syntheticEquivalent(model: Model, union: UnionShape): UnionShape {
|
||||
val (errorMembers, eventMembers) = union.members().partition { member ->
|
||||
model.expectShape(member.target).hasTrait<ErrorTrait>()
|
||||
}
|
||||
return union.toBuilder()
|
||||
.members(eventMembers)
|
||||
.addTrait(SyntheticEventStreamUnionTrait(errorMembers))
|
||||
.build()
|
||||
}
|
||||
}
|
|
@ -17,23 +17,25 @@ import software.amazon.smithy.rust.codegen.util.orNull
|
|||
import java.util.Optional
|
||||
import kotlin.streams.toList
|
||||
|
||||
typealias StructureModifier = (OperationShape, StructureShape?) -> StructureShape?
|
||||
|
||||
/**
|
||||
* Generate synthetic Input and Output structures for operations.
|
||||
*/
|
||||
class OperationNormalizer(private val model: Model) {
|
||||
object OperationNormalizer {
|
||||
// Functions to construct synthetic shape IDs—Don't rely on these in external code.
|
||||
// Rename safety: Operations cannot be renamed
|
||||
private fun OperationShape.syntheticInputId() = ShapeId.fromParts(this.id.namespace, "${this.id.name}Input")
|
||||
private fun OperationShape.syntheticOutputId() = ShapeId.fromParts(this.id.namespace, "${this.id.name}Output")
|
||||
/**
|
||||
* Add synthetic input & output shapes to every Operation in model. The generated shapes will be marked with
|
||||
* [SyntheticInputTrait] and [SyntheticOutputTrait] respectively. Shapes will be added _even_ if the operation does
|
||||
* not specify an input or an output.
|
||||
*/
|
||||
fun transformModel(): Model {
|
||||
fun transform(model: Model): Model {
|
||||
val transformer = ModelTransformer.create()
|
||||
val operations = model.shapes(OperationShape::class.java).toList()
|
||||
val newShapes = operations.flatMap { operation ->
|
||||
// Generate or modify the input and output of the given `Operation` to be a unique shape
|
||||
syntheticInputShapes(operation) + syntheticOutputShapes(operation)
|
||||
syntheticInputShapes(model, operation) + syntheticOutputShapes(model, operation)
|
||||
}
|
||||
val modelWithOperationInputs = model.toBuilder().addShapes(newShapes).build()
|
||||
return transformer.mapShapes(modelWithOperationInputs) {
|
||||
|
@ -49,7 +51,7 @@ class OperationNormalizer(private val model: Model) {
|
|||
}
|
||||
}
|
||||
|
||||
private fun syntheticOutputShapes(operation: OperationShape): List<StructureShape> {
|
||||
private fun syntheticOutputShapes(model: Model, operation: OperationShape): List<StructureShape> {
|
||||
val outputId = operation.syntheticOutputId()
|
||||
val outputShapeBuilder = operation.output.map { shapeId ->
|
||||
model.expectShape(shapeId, StructureShape::class.java).toBuilder().rename(outputId)
|
||||
|
@ -63,7 +65,7 @@ class OperationNormalizer(private val model: Model) {
|
|||
return listOfNotNull(outputShape)
|
||||
}
|
||||
|
||||
private fun syntheticInputShapes(operation: OperationShape): List<StructureShape> {
|
||||
private fun syntheticInputShapes(model: Model, operation: OperationShape): List<StructureShape> {
|
||||
val inputId = operation.syntheticInputId()
|
||||
val inputShapeBuilder = operation.input.map { shapeId ->
|
||||
model.expectShape(shapeId, StructureShape::class.java).toBuilder().rename(inputId)
|
||||
|
@ -79,13 +81,6 @@ class OperationNormalizer(private val model: Model) {
|
|||
}
|
||||
|
||||
private fun empty(id: ShapeId) = StructureShape.builder().id(id)
|
||||
|
||||
companion object {
|
||||
// Functions to construct synthetic shape IDs—Don't rely on these in external code.
|
||||
// Rename safety: Operations cannot be renamed
|
||||
private fun OperationShape.syntheticInputId() = ShapeId.fromParts(this.id.namespace, "${this.id.name}Input")
|
||||
private fun OperationShape.syntheticOutputId() = ShapeId.fromParts(this.id.namespace, "${this.id.name}Output")
|
||||
}
|
||||
}
|
||||
|
||||
private fun StructureShape.Builder.rename(newId: ShapeId): StructureShape.Builder {
|
||||
|
|
|
@ -13,23 +13,38 @@ import software.amazon.smithy.rust.codegen.util.findStreamingMember
|
|||
import software.amazon.smithy.rust.codegen.util.orNull
|
||||
import java.util.logging.Logger
|
||||
|
||||
// TODO(EventStream): Remove this class once the Event Stream implementation is stable
|
||||
/** Transformer to REMOVE operations that use EventStreaming until event streaming is supported */
|
||||
object RemoveEventStreamOperations {
|
||||
private val logger = Logger.getLogger(javaClass.name)
|
||||
fun transform(model: Model): Model = ModelTransformer.create().filterShapes(model) { parentShape ->
|
||||
if (parentShape !is OperationShape) {
|
||||
true
|
||||
} else {
|
||||
val ioShapes = listOfNotNull(parentShape.output.orNull(), parentShape.input.orNull()).map { model.expectShape(it, StructureShape::class.java) }
|
||||
val hasEventStream = ioShapes.any { ioShape ->
|
||||
val streamingMember = ioShape.findStreamingMember(model)?.let { model.expectShape(it.target) }
|
||||
streamingMember?.isUnionShape ?: false
|
||||
}
|
||||
// If a streaming member has a union trait, it is an event stream. Event Streams are not currently supported
|
||||
// by the SDK, so if we generate this API it won't work.
|
||||
(!hasEventStream).also {
|
||||
if (!it) {
|
||||
logger.info("Removed $parentShape from model because it targets an event stream")
|
||||
|
||||
private fun eventStreamEnabled(): Boolean =
|
||||
System.getenv()["SMITHYRS_EXPERIMENTAL_EVENTSTREAM"] == "1"
|
||||
|
||||
fun transform(model: Model): Model {
|
||||
if (eventStreamEnabled()) {
|
||||
return model
|
||||
}
|
||||
return ModelTransformer.create().filterShapes(model) { parentShape ->
|
||||
if (parentShape !is OperationShape) {
|
||||
true
|
||||
} else {
|
||||
val ioShapes = listOfNotNull(parentShape.output.orNull(), parentShape.input.orNull()).map {
|
||||
model.expectShape(
|
||||
it,
|
||||
StructureShape::class.java
|
||||
)
|
||||
}
|
||||
val hasEventStream = ioShapes.any { ioShape ->
|
||||
val streamingMember = ioShape.findStreamingMember(model)?.let { model.expectShape(it.target) }
|
||||
streamingMember?.isUnionShape ?: false
|
||||
}
|
||||
// If a streaming member has a union trait, it is an event stream. Event Streams are not currently supported
|
||||
// by the SDK, so if we generate this API it won't work.
|
||||
(!hasEventStream).also {
|
||||
if (!it) {
|
||||
logger.info("Removed $parentShape from model because it targets an event stream")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -157,7 +157,12 @@ class TestWriterDelegator(fileManifest: FileManifest, symbolProvider: RustSymbol
|
|||
val baseDir: Path = fileManifest.baseDir
|
||||
}
|
||||
|
||||
fun TestWriterDelegator.compileAndTest() {
|
||||
/**
|
||||
* Setting `runClippy` to true can be helpful when debugging clippy failures, but
|
||||
* should generally be set to false to avoid invalidating the Cargo cache between
|
||||
* every unit test run.
|
||||
*/
|
||||
fun TestWriterDelegator.compileAndTest(runClippy: Boolean = false) {
|
||||
val stubModel = """
|
||||
namespace fake
|
||||
service Fake {
|
||||
|
@ -183,6 +188,9 @@ fun TestWriterDelegator.compileAndTest() {
|
|||
// cargo fmt errors are useless, ignore
|
||||
}
|
||||
"cargo test".runCommand(baseDir, mapOf("RUSTFLAGS" to "-A dead_code"))
|
||||
if (runClippy) {
|
||||
"cargo clippy".runCommand(baseDir)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: unify these test helpers a bit
|
||||
|
|
|
@ -11,12 +11,14 @@ import software.amazon.smithy.model.shapes.BooleanShape
|
|||
import software.amazon.smithy.model.shapes.MemberShape
|
||||
import software.amazon.smithy.model.shapes.NumberShape
|
||||
import software.amazon.smithy.model.shapes.OperationShape
|
||||
import software.amazon.smithy.model.shapes.ServiceShape
|
||||
import software.amazon.smithy.model.shapes.Shape
|
||||
import software.amazon.smithy.model.shapes.ShapeId
|
||||
import software.amazon.smithy.model.shapes.StructureShape
|
||||
import software.amazon.smithy.model.shapes.UnionShape
|
||||
import software.amazon.smithy.model.traits.StreamingTrait
|
||||
import software.amazon.smithy.model.traits.Trait
|
||||
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
|
||||
|
||||
inline fun <reified T : Shape> Model.lookup(shapeId: String): T {
|
||||
return this.expectShape(ShapeId.from(shapeId), T::class.java)
|
||||
|
@ -42,6 +44,37 @@ fun StructureShape.hasStreamingMember(model: Model) = this.findStreamingMember(m
|
|||
fun UnionShape.hasStreamingMember(model: Model) = this.findMemberWithTrait<StreamingTrait>(model) != null
|
||||
fun MemberShape.isStreaming(model: Model) = this.getMemberTrait(model, StreamingTrait::class.java).isPresent
|
||||
|
||||
fun UnionShape.isEventStream(): Boolean {
|
||||
return hasTrait(StreamingTrait::class.java)
|
||||
}
|
||||
fun MemberShape.isEventStream(model: Model): Boolean {
|
||||
return (model.expectShape(target) as? UnionShape)?.isEventStream() ?: false
|
||||
}
|
||||
fun MemberShape.isInputEventStream(model: Model): Boolean {
|
||||
return isEventStream(model) && model.expectShape(container).hasTrait<SyntheticInputTrait>()
|
||||
}
|
||||
fun MemberShape.isOutputEventStream(model: Model): Boolean {
|
||||
return isEventStream(model) && model.expectShape(container).hasTrait<SyntheticInputTrait>()
|
||||
}
|
||||
private fun Shape.hasEventStreamMember(model: Model): Boolean {
|
||||
return members().any { it.isEventStream(model) }
|
||||
}
|
||||
fun OperationShape.isInputEventStream(model: Model): Boolean {
|
||||
return input.map { id -> model.expectShape(id).hasEventStreamMember(model) }.orElse(false)
|
||||
}
|
||||
fun OperationShape.isOutputEventStream(model: Model): Boolean {
|
||||
return output.map { id -> model.expectShape(id).hasEventStreamMember(model) }.orElse(false)
|
||||
}
|
||||
fun OperationShape.isEventStream(model: Model): Boolean {
|
||||
return isInputEventStream(model) || isOutputEventStream(model)
|
||||
}
|
||||
fun ServiceShape.hasEventStreamOperations(model: Model): Boolean = operations.any { id ->
|
||||
// Don't assume all of the looked up operation ids are operation shapes. Our
|
||||
// synthetic input/output structure shapes can have the same name as an operation,
|
||||
// as is the case with `kinesisanalytics`.
|
||||
model.getShape(id).orNull()?.let { it is OperationShape && it.isEventStream(model) } ?: false
|
||||
}
|
||||
|
||||
/*
|
||||
* Returns the member of this structure targeted with streaming trait (if it exists).
|
||||
*
|
||||
|
|
|
@ -101,7 +101,7 @@ class RequestBindingGeneratorTest {
|
|||
stringHeader: String
|
||||
}
|
||||
""".asSmithyModel()
|
||||
private val model = OperationNormalizer(baseModel).transformModel()
|
||||
private val model = OperationNormalizer.transform(baseModel)
|
||||
|
||||
private val operationShape = model.expectShape(ShapeId.from("smithy.example#PutObject"), OperationShape::class.java)
|
||||
private val inputShape = model.expectShape(operationShape.input.get(), StructureShape::class.java)
|
||||
|
|
|
@ -16,6 +16,7 @@ import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig
|
|||
import software.amazon.smithy.rust.codegen.smithy.generators.http.ResponseBindingGenerator
|
||||
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation
|
||||
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpTraitHttpBindingResolver
|
||||
import software.amazon.smithy.rust.codegen.smithy.protocols.RestJson
|
||||
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
|
||||
import software.amazon.smithy.rust.codegen.testutil.TestWorkspace
|
||||
import software.amazon.smithy.rust.codegen.testutil.asSmithyModel
|
||||
|
@ -65,7 +66,7 @@ class ResponseBindingGeneratorTest {
|
|||
additional: String,
|
||||
}
|
||||
""".asSmithyModel()
|
||||
private val model = OperationNormalizer(baseModel).transformModel()
|
||||
private val model = OperationNormalizer.transform(baseModel)
|
||||
private val operationShape = model.expectShape(ShapeId.from("smithy.example#PutObject"), OperationShape::class.java)
|
||||
private val symbolProvider = testSymbolProvider(model)
|
||||
private val testProtocolConfig: ProtocolConfig = testProtocolConfig(model)
|
||||
|
@ -78,7 +79,9 @@ class ResponseBindingGeneratorTest {
|
|||
.filter { it.location == HttpLocation.HEADER }
|
||||
bindings.forEach { binding ->
|
||||
val runtimeType = ResponseBindingGenerator(
|
||||
testProtocolConfig, operationShape
|
||||
RestJson(testProtocolConfig),
|
||||
testProtocolConfig,
|
||||
operationShape
|
||||
).generateDeserializeHeaderFn(binding)
|
||||
// little hack to force these functions to be generated
|
||||
rust("// use #T;", runtimeType)
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0.
|
||||
*/
|
||||
|
||||
package software.amazon.smithy.rust.codegen.smithy
|
||||
|
||||
import io.kotest.matchers.shouldBe
|
||||
import org.junit.jupiter.api.Test
|
||||
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
|
||||
import software.amazon.smithy.rust.codegen.rustlang.CratesIo
|
||||
import software.amazon.smithy.rust.codegen.rustlang.DependencyScope.Compile
|
||||
|
||||
class CodegenDelegatorTest {
|
||||
@Test
|
||||
fun testMergeDependencyFeatures() {
|
||||
val merged = mergeDependencyFeatures(
|
||||
listOf(
|
||||
CargoDependency("A", CratesIo("1"), Compile, optional = false, features = setOf()),
|
||||
CargoDependency("A", CratesIo("1"), Compile, optional = false, features = setOf("f1")),
|
||||
CargoDependency("A", CratesIo("1"), Compile, optional = false, features = setOf("f2")),
|
||||
CargoDependency("A", CratesIo("1"), Compile, optional = false, features = setOf("f1", "f2")),
|
||||
|
||||
CargoDependency("B", CratesIo("2"), Compile, optional = false, features = setOf()),
|
||||
CargoDependency("B", CratesIo("2"), Compile, optional = true, features = setOf()),
|
||||
|
||||
CargoDependency("C", CratesIo("3"), Compile, optional = true, features = setOf()),
|
||||
CargoDependency("C", CratesIo("3"), Compile, optional = true, features = setOf()),
|
||||
).shuffled()
|
||||
)
|
||||
|
||||
merged shouldBe setOf(
|
||||
CargoDependency("A", CratesIo("1"), Compile, optional = false, features = setOf("f1", "f2")),
|
||||
CargoDependency("B", CratesIo("2"), Compile, optional = false, features = setOf()),
|
||||
CargoDependency("C", CratesIo("3"), Compile, optional = true, features = setOf()),
|
||||
)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,92 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0.
|
||||
*/
|
||||
|
||||
package software.amazon.smithy.rust.codegen.smithy
|
||||
|
||||
import io.kotest.matchers.shouldBe
|
||||
import org.junit.jupiter.api.Test
|
||||
import software.amazon.smithy.model.shapes.MemberShape
|
||||
import software.amazon.smithy.model.shapes.ServiceShape
|
||||
import software.amazon.smithy.model.shapes.ShapeId
|
||||
import software.amazon.smithy.rust.codegen.rustlang.RustType
|
||||
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
|
||||
import software.amazon.smithy.rust.codegen.testutil.TestRuntimeConfig
|
||||
import software.amazon.smithy.rust.codegen.testutil.asSmithyModel
|
||||
|
||||
class EventStreamSymbolProviderTest {
|
||||
@Test
|
||||
fun `it should adjust types for operations with event streams`() {
|
||||
// Transform the model so that it has synthetic inputs/outputs
|
||||
val model = OperationNormalizer.transform(
|
||||
"""
|
||||
namespace test
|
||||
|
||||
structure Something { stuff: Blob }
|
||||
|
||||
@streaming
|
||||
union SomeStream {
|
||||
Something: Something,
|
||||
}
|
||||
|
||||
structure TestInput { inputStream: SomeStream }
|
||||
structure TestOutput { outputStream: SomeStream }
|
||||
operation TestOperation {
|
||||
input: TestInput,
|
||||
output: TestOutput,
|
||||
}
|
||||
service TestService { version: "123", operations: [TestOperation] }
|
||||
""".asSmithyModel()
|
||||
)
|
||||
|
||||
val service = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape
|
||||
val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, DefaultConfig), model)
|
||||
|
||||
// Look up the synthetic input/output rather than the original input/output
|
||||
val inputStream = model.expectShape(ShapeId.from("test#TestOperationInput\$inputStream")) as MemberShape
|
||||
val outputStream = model.expectShape(ShapeId.from("test#TestOperationOutput\$outputStream")) as MemberShape
|
||||
|
||||
val inputType = provider.toSymbol(inputStream).rustType()
|
||||
val outputType = provider.toSymbol(outputStream).rustType()
|
||||
|
||||
inputType shouldBe RustType.Opaque("EventStreamInput<crate::model::SomeStream>", "smithy_http::event_stream")
|
||||
outputType shouldBe RustType.Opaque("Receiver<crate::model::SomeStream, crate::error::TestOperationError>", "smithy_http::event_stream")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `it should leave alone types for operations without event streams`() {
|
||||
val model = OperationNormalizer.transform(
|
||||
"""
|
||||
namespace test
|
||||
|
||||
structure Something { stuff: Blob }
|
||||
|
||||
union NotStreaming {
|
||||
Something: Something,
|
||||
}
|
||||
|
||||
structure TestInput { inputStream: NotStreaming }
|
||||
structure TestOutput { outputStream: NotStreaming }
|
||||
operation TestOperation {
|
||||
input: TestInput,
|
||||
output: TestOutput,
|
||||
}
|
||||
service TestService { version: "123", operations: [TestOperation] }
|
||||
""".asSmithyModel()
|
||||
)
|
||||
|
||||
val service = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape
|
||||
val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, DefaultConfig), model)
|
||||
|
||||
// Look up the synthetic input/output rather than the original input/output
|
||||
val inputStream = model.expectShape(ShapeId.from("test#TestOperationInput\$inputStream")) as MemberShape
|
||||
val outputStream = model.expectShape(ShapeId.from("test#TestOperationOutput\$outputStream")) as MemberShape
|
||||
|
||||
val inputType = provider.toSymbol(inputStream).rustType()
|
||||
val outputType = provider.toSymbol(outputStream).rustType()
|
||||
|
||||
inputType shouldBe RustType.Option(RustType.Opaque("NotStreaming", "crate::model"))
|
||||
outputType shouldBe RustType.Option(RustType.Opaque("NotStreaming", "crate::model"))
|
||||
}
|
||||
}
|
|
@ -34,7 +34,7 @@ internal class StreamingShapeSymbolProviderTest {
|
|||
fun `generates a byte stream on streaming output`() {
|
||||
// we could test exactly the streaming shape symbol provider, but we actually care about is the full stack
|
||||
// "doing the right thing"
|
||||
val modelWithOperationTraits = OperationNormalizer(model).transformModel()
|
||||
val modelWithOperationTraits = OperationNormalizer.transform(model)
|
||||
val symbolProvider = testSymbolProvider(modelWithOperationTraits)
|
||||
symbolProvider.toSymbol(modelWithOperationTraits.lookup<MemberShape>("test#GenerateSpeechOutput\$data")).name shouldBe ("byte_stream::ByteStream")
|
||||
symbolProvider.toSymbol(modelWithOperationTraits.lookup<MemberShape>("test#GenerateSpeechInput\$data")).name shouldBe ("byte_stream::ByteStream")
|
||||
|
@ -42,7 +42,7 @@ internal class StreamingShapeSymbolProviderTest {
|
|||
|
||||
@Test
|
||||
fun `streaming members have a default`() {
|
||||
val modelWithOperationTraits = OperationNormalizer(model).transformModel()
|
||||
val modelWithOperationTraits = OperationNormalizer.transform(model)
|
||||
val symbolProvider = testSymbolProvider(modelWithOperationTraits)
|
||||
|
||||
val outputSymbol = symbolProvider.toSymbol(modelWithOperationTraits.lookup<MemberShape>("test#GenerateSpeechOutput\$data"))
|
||||
|
|
|
@ -22,8 +22,6 @@ import software.amazon.smithy.rust.codegen.smithy.RuntimeType
|
|||
import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator
|
||||
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
|
||||
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolMap
|
||||
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
|
||||
import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStreamOperations
|
||||
import software.amazon.smithy.rust.codegen.testutil.asSmithyModel
|
||||
import software.amazon.smithy.rust.codegen.testutil.generatePluginContext
|
||||
import software.amazon.smithy.rust.codegen.util.CommandFailed
|
||||
|
@ -163,9 +161,7 @@ class HttpProtocolTestGeneratorTest {
|
|||
return TestProtocol(protocolConfig)
|
||||
}
|
||||
|
||||
override fun transformModel(model: Model): Model {
|
||||
return OperationNormalizer(model).transformModel().let(RemoveEventStreamOperations::transform)
|
||||
}
|
||||
override fun transformModel(model: Model): Model = model
|
||||
|
||||
override fun support(): ProtocolSupport {
|
||||
return ProtocolSupport(true, true, true, true)
|
||||
|
|
|
@ -0,0 +1,307 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0.
|
||||
*/
|
||||
|
||||
package software.amazon.smithy.rust.codegen.smithy.protocols
|
||||
|
||||
import org.junit.jupiter.api.extension.ExtensionContext
|
||||
import org.junit.jupiter.params.provider.Arguments
|
||||
import org.junit.jupiter.params.provider.ArgumentsProvider
|
||||
import software.amazon.smithy.model.Model
|
||||
import software.amazon.smithy.model.shapes.OperationShape
|
||||
import software.amazon.smithy.model.shapes.ServiceShape
|
||||
import software.amazon.smithy.model.shapes.Shape
|
||||
import software.amazon.smithy.model.shapes.ShapeId
|
||||
import software.amazon.smithy.model.shapes.StructureShape
|
||||
import software.amazon.smithy.model.shapes.UnionShape
|
||||
import software.amazon.smithy.model.traits.ErrorTrait
|
||||
import software.amazon.smithy.rust.codegen.rustlang.RustModule
|
||||
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
|
||||
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
|
||||
import software.amazon.smithy.rust.codegen.smithy.generators.BuilderGenerator
|
||||
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig
|
||||
import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator
|
||||
import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator
|
||||
import software.amazon.smithy.rust.codegen.smithy.generators.error.CombinedErrorGenerator
|
||||
import software.amazon.smithy.rust.codegen.smithy.generators.implBlock
|
||||
import software.amazon.smithy.rust.codegen.smithy.transformers.EventStreamNormalizer
|
||||
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
|
||||
import software.amazon.smithy.rust.codegen.testutil.TestWorkspace
|
||||
import software.amazon.smithy.rust.codegen.testutil.TestWriterDelegator
|
||||
import software.amazon.smithy.rust.codegen.testutil.asSmithyModel
|
||||
import software.amazon.smithy.rust.codegen.testutil.renderWithModelBuilder
|
||||
import software.amazon.smithy.rust.codegen.testutil.testSymbolProvider
|
||||
import software.amazon.smithy.rust.codegen.util.hasTrait
|
||||
import software.amazon.smithy.rust.codegen.util.lookup
|
||||
import software.amazon.smithy.rust.codegen.util.outputShape
|
||||
import java.util.stream.Stream
|
||||
|
||||
private fun fillInBaseModel(
|
||||
protocolName: String,
|
||||
extraServiceAnnotations: String = "",
|
||||
): String = """
|
||||
namespace test
|
||||
|
||||
use aws.protocols#$protocolName
|
||||
|
||||
union TestUnion {
|
||||
Foo: String,
|
||||
Bar: Integer,
|
||||
}
|
||||
structure TestStruct {
|
||||
someString: String,
|
||||
someInt: Integer,
|
||||
}
|
||||
|
||||
@error("client")
|
||||
structure SomeError {
|
||||
Message: String,
|
||||
}
|
||||
|
||||
structure MessageWithBlob { @eventPayload data: Blob }
|
||||
structure MessageWithString { @eventPayload data: String }
|
||||
structure MessageWithStruct { @eventPayload someStruct: TestStruct }
|
||||
structure MessageWithUnion { @eventPayload someUnion: TestUnion }
|
||||
structure MessageWithHeaders {
|
||||
@eventHeader blob: Blob,
|
||||
@eventHeader boolean: Boolean,
|
||||
@eventHeader byte: Byte,
|
||||
@eventHeader int: Integer,
|
||||
@eventHeader long: Long,
|
||||
@eventHeader short: Short,
|
||||
@eventHeader string: String,
|
||||
@eventHeader timestamp: Timestamp,
|
||||
}
|
||||
structure MessageWithHeaderAndPayload {
|
||||
@eventHeader header: String,
|
||||
@eventPayload payload: Blob,
|
||||
}
|
||||
structure MessageWithNoHeaderPayloadTraits {
|
||||
someInt: Integer,
|
||||
someString: String,
|
||||
}
|
||||
|
||||
@streaming
|
||||
union TestStream {
|
||||
MessageWithBlob: MessageWithBlob,
|
||||
MessageWithString: MessageWithString,
|
||||
MessageWithStruct: MessageWithStruct,
|
||||
MessageWithUnion: MessageWithUnion,
|
||||
MessageWithHeaders: MessageWithHeaders,
|
||||
MessageWithHeaderAndPayload: MessageWithHeaderAndPayload,
|
||||
MessageWithNoHeaderPayloadTraits: MessageWithNoHeaderPayloadTraits,
|
||||
SomeError: SomeError,
|
||||
}
|
||||
structure TestStreamInputOutput { @required value: TestStream }
|
||||
operation TestStreamOp {
|
||||
input: TestStreamInputOutput,
|
||||
output: TestStreamInputOutput,
|
||||
errors: [SomeError],
|
||||
}
|
||||
$extraServiceAnnotations
|
||||
@$protocolName
|
||||
service TestService { version: "123", operations: [TestStreamOp] }
|
||||
"""
|
||||
|
||||
object EventStreamTestModels {
|
||||
fun restJson1(): Model = fillInBaseModel("restJson1").asSmithyModel()
|
||||
fun restXml(): Model = fillInBaseModel("restXml").asSmithyModel()
|
||||
fun awsJson11(): Model = fillInBaseModel("awsJson1_1").asSmithyModel()
|
||||
fun awsQuery(): Model = fillInBaseModel("awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
|
||||
fun ec2Query(): Model = fillInBaseModel("ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
|
||||
|
||||
data class TestCase(
|
||||
val protocolShapeId: String,
|
||||
val model: Model,
|
||||
val contentType: String,
|
||||
val validTestStruct: String,
|
||||
val validMessageWithNoHeaderPayloadTraits: String,
|
||||
val validTestUnion: String,
|
||||
val validSomeError: String,
|
||||
val protocolBuilder: (ProtocolConfig) -> Protocol,
|
||||
) {
|
||||
override fun toString(): String = protocolShapeId
|
||||
}
|
||||
|
||||
class ModelArgumentsProvider : ArgumentsProvider {
|
||||
override fun provideArguments(context: ExtensionContext?): Stream<out Arguments> =
|
||||
Stream.of(
|
||||
Arguments.of(
|
||||
TestCase(
|
||||
protocolShapeId = "aws.protocols#restJson1",
|
||||
model = restJson1(),
|
||||
contentType = "application/json",
|
||||
validTestStruct = """{"someString":"hello","someInt":5}""",
|
||||
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
|
||||
validTestUnion = """{"Foo":"hello"}""",
|
||||
validSomeError = """{"Message":"some error"}""",
|
||||
) { RestJson(it) }
|
||||
),
|
||||
Arguments.of(
|
||||
TestCase(
|
||||
protocolShapeId = "aws.protocols#awsJson1_1",
|
||||
model = awsJson11(),
|
||||
contentType = "application/x-amz-json-1.1",
|
||||
validTestStruct = """{"someString":"hello","someInt":5}""",
|
||||
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
|
||||
validTestUnion = """{"Foo":"hello"}""",
|
||||
validSomeError = """{"Message":"some error"}""",
|
||||
) { AwsJson(it, AwsJsonVersion.Json11) }
|
||||
),
|
||||
Arguments.of(
|
||||
TestCase(
|
||||
protocolShapeId = "aws.protocols#restXml",
|
||||
model = restXml(),
|
||||
contentType = "text/xml",
|
||||
validTestStruct = """
|
||||
<TestStruct>
|
||||
<someString>hello</someString>
|
||||
<someInt>5</someInt>
|
||||
</TestStruct>
|
||||
""".trimIndent(),
|
||||
validMessageWithNoHeaderPayloadTraits = """
|
||||
<MessageWithNoHeaderPayloadTraits>
|
||||
<someString>hello</someString>
|
||||
<someInt>5</someInt>
|
||||
</MessageWithNoHeaderPayloadTraits>
|
||||
""".trimIndent(),
|
||||
validTestUnion = "<TestUnion><Foo>hello</Foo></TestUnion>",
|
||||
validSomeError = """
|
||||
<ErrorResponse>
|
||||
<Error>
|
||||
<Type>SomeError</Type>
|
||||
<Code>SomeError</Code>
|
||||
<Message>some error</Message>
|
||||
</Error>
|
||||
</ErrorResponse>
|
||||
""".trimIndent()
|
||||
) { RestXml(it) }
|
||||
),
|
||||
Arguments.of(
|
||||
TestCase(
|
||||
protocolShapeId = "aws.protocols#awsQuery",
|
||||
model = awsQuery(),
|
||||
contentType = "application/x-www-form-urlencoded",
|
||||
validTestStruct = """
|
||||
<TestStruct>
|
||||
<someString>hello</someString>
|
||||
<someInt>5</someInt>
|
||||
</TestStruct>
|
||||
""".trimIndent(),
|
||||
validMessageWithNoHeaderPayloadTraits = """
|
||||
<MessageWithNoHeaderPayloadTraits>
|
||||
<someString>hello</someString>
|
||||
<someInt>5</someInt>
|
||||
</MessageWithNoHeaderPayloadTraits>
|
||||
""".trimIndent(),
|
||||
validTestUnion = "<TestUnion><Foo>hello</Foo></TestUnion>",
|
||||
validSomeError = """
|
||||
<ErrorResponse>
|
||||
<Error>
|
||||
<Type>SomeError</Type>
|
||||
<Code>SomeError</Code>
|
||||
<Message>some error</Message>
|
||||
</Error>
|
||||
</ErrorResponse>
|
||||
""".trimIndent()
|
||||
) { AwsQueryProtocol(it) }
|
||||
),
|
||||
Arguments.of(
|
||||
TestCase(
|
||||
protocolShapeId = "aws.protocols#ec2Query",
|
||||
model = ec2Query(),
|
||||
contentType = "application/x-www-form-urlencoded",
|
||||
validTestStruct = """
|
||||
<TestStruct>
|
||||
<someString>hello</someString>
|
||||
<someInt>5</someInt>
|
||||
</TestStruct>
|
||||
""".trimIndent(),
|
||||
validMessageWithNoHeaderPayloadTraits = """
|
||||
<MessageWithNoHeaderPayloadTraits>
|
||||
<someString>hello</someString>
|
||||
<someInt>5</someInt>
|
||||
</MessageWithNoHeaderPayloadTraits>
|
||||
""".trimIndent(),
|
||||
validTestUnion = "<TestUnion><Foo>hello</Foo></TestUnion>",
|
||||
validSomeError = """
|
||||
<Response>
|
||||
<Errors>
|
||||
<Error>
|
||||
<Type>SomeError</Type>
|
||||
<Code>SomeError</Code>
|
||||
<Message>some error</Message>
|
||||
</Error>
|
||||
</Error>
|
||||
</Response>
|
||||
""".trimIndent()
|
||||
) { Ec2QueryProtocol(it) }
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
data class TestEventStreamProject(
|
||||
val model: Model,
|
||||
val serviceShape: ServiceShape,
|
||||
val operationShape: OperationShape,
|
||||
val streamShape: UnionShape,
|
||||
val symbolProvider: RustSymbolProvider,
|
||||
val project: TestWriterDelegator,
|
||||
)
|
||||
|
||||
object EventStreamTestTools {
|
||||
fun generateTestProject(model: Model): TestEventStreamProject {
|
||||
val model = EventStreamNormalizer.transform(OperationNormalizer.transform(model))
|
||||
val serviceShape = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape
|
||||
val operationShape = model.expectShape(ShapeId.from("test#TestStreamOp")) as OperationShape
|
||||
val unionShape = model.expectShape(ShapeId.from("test#TestStream")) as UnionShape
|
||||
|
||||
val symbolProvider = testSymbolProvider(model)
|
||||
val project = TestWorkspace.testProject(symbolProvider)
|
||||
project.withModule(RustModule.default("error", public = true)) {
|
||||
CombinedErrorGenerator(model, symbolProvider, operationShape).render(it)
|
||||
for (shape in model.shapes().filter { shape -> shape.isStructureShape && shape.hasTrait<ErrorTrait>() }) {
|
||||
StructureGenerator(model, symbolProvider, it, shape as StructureShape).render()
|
||||
val builderGen = BuilderGenerator(model, symbolProvider, shape)
|
||||
builderGen.render(it)
|
||||
it.implBlock(shape, symbolProvider) {
|
||||
builderGen.renderConvenienceMethod(this)
|
||||
}
|
||||
}
|
||||
}
|
||||
project.withModule(RustModule.default("model", public = true)) {
|
||||
val inputOutput = model.lookup<StructureShape>("test#TestStreamInputOutput")
|
||||
recursivelyGenerateModels(model, symbolProvider, inputOutput, it)
|
||||
}
|
||||
project.withModule(RustModule.default("output", public = true)) {
|
||||
operationShape.outputShape(model).renderWithModelBuilder(model, symbolProvider, it)
|
||||
}
|
||||
println("file:///${project.baseDir}/src/error.rs")
|
||||
println("file:///${project.baseDir}/src/event_stream.rs")
|
||||
println("file:///${project.baseDir}/src/event_stream_serde.rs")
|
||||
println("file:///${project.baseDir}/src/lib.rs")
|
||||
println("file:///${project.baseDir}/src/model.rs")
|
||||
return TestEventStreamProject(model, serviceShape, operationShape, unionShape, symbolProvider, project)
|
||||
}
|
||||
|
||||
private fun recursivelyGenerateModels(
|
||||
model: Model,
|
||||
symbolProvider: RustSymbolProvider,
|
||||
shape: Shape,
|
||||
writer: RustWriter
|
||||
) {
|
||||
for (member in shape.members()) {
|
||||
val target = model.expectShape(member.target)
|
||||
if (target is StructureShape || target is UnionShape) {
|
||||
if (target is StructureShape) {
|
||||
target.renderWithModelBuilder(model, symbolProvider, writer)
|
||||
} else if (target is UnionShape) {
|
||||
UnionGenerator(model, symbolProvider, writer, target).render()
|
||||
}
|
||||
recursivelyGenerateModels(model, symbolProvider, target, writer)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -42,7 +42,7 @@ class AwsQueryParserGeneratorTest {
|
|||
|
||||
@Test
|
||||
fun `it modifies operation parsing to include Response and Result tags`() {
|
||||
val model = RecursiveShapeBoxer.transform(OperationNormalizer(baseModel).transformModel())
|
||||
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
|
||||
val symbolProvider = testSymbolProvider(model)
|
||||
val parserGenerator = AwsQueryParserGenerator(
|
||||
testProtocolConfig(model),
|
||||
|
|
|
@ -42,7 +42,7 @@ class Ec2QueryParserGeneratorTest {
|
|||
|
||||
@Test
|
||||
fun `it modifies operation parsing to include Response and Result tags`() {
|
||||
val model = RecursiveShapeBoxer.transform(OperationNormalizer(baseModel).transformModel())
|
||||
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
|
||||
val symbolProvider = testSymbolProvider(model)
|
||||
val parserGenerator = Ec2QueryParserGenerator(
|
||||
testProtocolConfig(model),
|
||||
|
|
|
@ -0,0 +1,245 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0.
|
||||
*/
|
||||
|
||||
package software.amazon.smithy.rust.codegen.smithy.protocols.parse
|
||||
|
||||
import org.junit.jupiter.params.ParameterizedTest
|
||||
import org.junit.jupiter.params.provider.ArgumentsSource
|
||||
import software.amazon.smithy.model.shapes.ShapeId
|
||||
import software.amazon.smithy.rust.codegen.rustlang.rust
|
||||
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig
|
||||
import software.amazon.smithy.rust.codegen.smithy.protocols.EventStreamTestModels
|
||||
import software.amazon.smithy.rust.codegen.smithy.protocols.EventStreamTestTools
|
||||
import software.amazon.smithy.rust.codegen.testutil.TestRuntimeConfig
|
||||
import software.amazon.smithy.rust.codegen.testutil.compileAndTest
|
||||
import software.amazon.smithy.rust.codegen.testutil.unitTest
|
||||
|
||||
class EventStreamUnmarshallerGeneratorTest {
|
||||
@ParameterizedTest
|
||||
@ArgumentsSource(EventStreamTestModels.ModelArgumentsProvider::class)
|
||||
fun test(testCase: EventStreamTestModels.TestCase) {
|
||||
val test = EventStreamTestTools.generateTestProject(testCase.model)
|
||||
|
||||
val protocolConfig = ProtocolConfig(
|
||||
test.model,
|
||||
test.symbolProvider,
|
||||
TestRuntimeConfig,
|
||||
test.serviceShape,
|
||||
ShapeId.from(testCase.protocolShapeId),
|
||||
"test"
|
||||
)
|
||||
val protocol = testCase.protocolBuilder(protocolConfig)
|
||||
val generator = EventStreamUnmarshallerGenerator(
|
||||
protocol,
|
||||
test.model,
|
||||
TestRuntimeConfig,
|
||||
test.symbolProvider,
|
||||
test.operationShape,
|
||||
test.streamShape
|
||||
)
|
||||
|
||||
test.project.lib { writer ->
|
||||
// TODO(EventStream): Add test for bad content type
|
||||
// TODO(EventStream): Add test for generic error parsing
|
||||
writer.rust(
|
||||
"""
|
||||
use smithy_eventstream::frame::{Header, HeaderValue, Message, UnmarshallMessage, UnmarshalledMessage};
|
||||
use smithy_types::{Blob, Instant};
|
||||
use crate::error::*;
|
||||
use crate::model::*;
|
||||
|
||||
fn msg(
|
||||
message_type: &'static str,
|
||||
event_type: &'static str,
|
||||
content_type: &'static str,
|
||||
payload: &'static [u8],
|
||||
) -> Message {
|
||||
let message = Message::new(payload)
|
||||
.add_header(Header::new(":message-type", HeaderValue::String(message_type.into())))
|
||||
.add_header(Header::new(":content-type", HeaderValue::String(content_type.into())));
|
||||
if message_type == "event" {
|
||||
message.add_header(Header::new(":event-type", HeaderValue::String(event_type.into())))
|
||||
} else {
|
||||
message.add_header(Header::new(":exception-type", HeaderValue::String(event_type.into())))
|
||||
}
|
||||
}
|
||||
fn expect_event<T: std::fmt::Debug, E: std::fmt::Debug>(unmarshalled: UnmarshalledMessage<T, E>) -> T {
|
||||
match unmarshalled {
|
||||
UnmarshalledMessage::Event(event) => event,
|
||||
_ => panic!("expected event, got: {:?}", unmarshalled),
|
||||
}
|
||||
}
|
||||
fn expect_error<T: std::fmt::Debug, E: std::fmt::Debug>(unmarshalled: UnmarshalledMessage<T, E>) -> E {
|
||||
match unmarshalled {
|
||||
UnmarshalledMessage::Error(error) => error,
|
||||
_ => panic!("expected error, got: {:?}", unmarshalled),
|
||||
}
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
writer.unitTest(
|
||||
"""
|
||||
let message = msg("event", "MessageWithBlob", "application/octet-stream", b"hello, world!");
|
||||
let result = ${writer.format(generator.render())}().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
assert_eq!(
|
||||
TestStream::MessageWithBlob(
|
||||
MessageWithBlob::builder().data(Blob::new(&b"hello, world!"[..])).build()
|
||||
),
|
||||
expect_event(result.unwrap())
|
||||
);
|
||||
""",
|
||||
"message_with_blob",
|
||||
)
|
||||
|
||||
writer.unitTest(
|
||||
"""
|
||||
let message = msg("event", "MessageWithString", "application/octet-stream", b"hello, world!");
|
||||
let result = ${writer.format(generator.render())}().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
assert_eq!(
|
||||
TestStream::MessageWithString(MessageWithString::builder().data("hello, world!").build()),
|
||||
expect_event(result.unwrap())
|
||||
);
|
||||
""",
|
||||
"message_with_string",
|
||||
)
|
||||
|
||||
writer.unitTest(
|
||||
"""
|
||||
let message = msg(
|
||||
"event",
|
||||
"MessageWithStruct",
|
||||
"${testCase.contentType}",
|
||||
br#"${testCase.validTestStruct}"#
|
||||
);
|
||||
let result = ${writer.format(generator.render())}().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
assert_eq!(
|
||||
TestStream::MessageWithStruct(MessageWithStruct::builder().some_struct(
|
||||
TestStruct::builder()
|
||||
.some_string("hello")
|
||||
.some_int(5)
|
||||
.build()
|
||||
).build()),
|
||||
expect_event(result.unwrap())
|
||||
);
|
||||
""",
|
||||
"message_with_struct",
|
||||
)
|
||||
|
||||
writer.unitTest(
|
||||
"""
|
||||
let message = msg(
|
||||
"event",
|
||||
"MessageWithUnion",
|
||||
"${testCase.contentType}",
|
||||
br#"${testCase.validTestUnion}"#
|
||||
);
|
||||
let result = ${writer.format(generator.render())}().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
assert_eq!(
|
||||
TestStream::MessageWithUnion(MessageWithUnion::builder().some_union(
|
||||
TestUnion::Foo("hello".into())
|
||||
).build()),
|
||||
expect_event(result.unwrap())
|
||||
);
|
||||
""",
|
||||
"message_with_union",
|
||||
)
|
||||
|
||||
writer.unitTest(
|
||||
"""
|
||||
let message = msg("event", "MessageWithHeaders", "application/octet-stream", b"")
|
||||
.add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into())))
|
||||
.add_header(Header::new("boolean", HeaderValue::Bool(true)))
|
||||
.add_header(Header::new("byte", HeaderValue::Byte(55i8)))
|
||||
.add_header(Header::new("int", HeaderValue::Int32(100_000i32)))
|
||||
.add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64)))
|
||||
.add_header(Header::new("short", HeaderValue::Int16(16_000i16)))
|
||||
.add_header(Header::new("string", HeaderValue::String("test".into())))
|
||||
.add_header(Header::new("timestamp", HeaderValue::Timestamp(Instant::from_epoch_seconds(5))));
|
||||
let result = ${writer.format(generator.render())}().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
assert_eq!(
|
||||
TestStream::MessageWithHeaders(MessageWithHeaders::builder()
|
||||
.blob(Blob::new(&b"test"[..]))
|
||||
.boolean(true)
|
||||
.byte(55i8)
|
||||
.int(100_000i32)
|
||||
.long(9_000_000_000i64)
|
||||
.short(16_000i16)
|
||||
.string("test")
|
||||
.timestamp(Instant::from_epoch_seconds(5))
|
||||
.build()
|
||||
),
|
||||
expect_event(result.unwrap())
|
||||
);
|
||||
""",
|
||||
"message_with_headers",
|
||||
)
|
||||
|
||||
writer.unitTest(
|
||||
"""
|
||||
let message = msg("event", "MessageWithHeaderAndPayload", "application/octet-stream", b"payload")
|
||||
.add_header(Header::new("header", HeaderValue::String("header".into())));
|
||||
let result = ${writer.format(generator.render())}().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
assert_eq!(
|
||||
TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder()
|
||||
.header("header")
|
||||
.payload(Blob::new(&b"payload"[..]))
|
||||
.build()
|
||||
),
|
||||
expect_event(result.unwrap())
|
||||
);
|
||||
""",
|
||||
"message_with_header_and_payload",
|
||||
)
|
||||
|
||||
writer.unitTest(
|
||||
"""
|
||||
let message = msg(
|
||||
"event",
|
||||
"MessageWithNoHeaderPayloadTraits",
|
||||
"${testCase.contentType}",
|
||||
br#"${testCase.validMessageWithNoHeaderPayloadTraits}"#
|
||||
);
|
||||
let result = ${writer.format(generator.render())}().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
assert_eq!(
|
||||
TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder()
|
||||
.some_int(5)
|
||||
.some_string("hello")
|
||||
.build()
|
||||
),
|
||||
expect_event(result.unwrap())
|
||||
);
|
||||
""",
|
||||
"message_with_no_header_payload_traits",
|
||||
)
|
||||
|
||||
writer.unitTest(
|
||||
"""
|
||||
let message = msg(
|
||||
"exception",
|
||||
"SomeError",
|
||||
"${testCase.contentType}",
|
||||
br#"${testCase.validSomeError}"#
|
||||
);
|
||||
let result = ${writer.format(generator.render())}().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
match expect_error(result.unwrap()).kind {
|
||||
TestStreamOpErrorKind::SomeError(err) => assert_eq!(Some("some error"), err.message()),
|
||||
kind => panic!("expected SomeError, but got {:?}", kind),
|
||||
}
|
||||
""",
|
||||
"some_error",
|
||||
)
|
||||
}
|
||||
test.project.compileAndTest()
|
||||
}
|
||||
}
|
|
@ -106,7 +106,7 @@ class JsonParserGeneratorTest {
|
|||
|
||||
@Test
|
||||
fun `generates valid deserializers`() {
|
||||
val model = RecursiveShapeBoxer.transform(OperationNormalizer(baseModel).transformModel())
|
||||
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
|
||||
val symbolProvider = testSymbolProvider(model)
|
||||
val parserGenerator = JsonParserGenerator(
|
||||
testProtocolConfig(model),
|
||||
|
|
|
@ -90,7 +90,7 @@ internal class XmlBindingTraitParserGeneratorTest {
|
|||
|
||||
@Test
|
||||
fun `generates valid parsers`() {
|
||||
val model = RecursiveShapeBoxer.transform(OperationNormalizer(baseModel).transformModel())
|
||||
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
|
||||
val symbolProvider = testSymbolProvider(model)
|
||||
val parserGenerator = XmlBindingTraitParserGenerator(
|
||||
testProtocolConfig(model),
|
||||
|
|
|
@ -81,7 +81,7 @@ class AwsQuerySerializerGeneratorTest {
|
|||
|
||||
@Test
|
||||
fun `generates valid serializers`() {
|
||||
val model = RecursiveShapeBoxer.transform(OperationNormalizer(baseModel).transformModel())
|
||||
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
|
||||
val symbolProvider = testSymbolProvider(model)
|
||||
val parserGenerator = AwsQuerySerializerGenerator(testProtocolConfig(model))
|
||||
val operationGenerator = parserGenerator.operationSerializer(model.lookup("test#Op"))
|
||||
|
|
|
@ -80,7 +80,7 @@ class Ec2QuerySerializerGeneratorTest {
|
|||
|
||||
@Test
|
||||
fun `generates valid serializers`() {
|
||||
val model = RecursiveShapeBoxer.transform(OperationNormalizer(baseModel).transformModel())
|
||||
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
|
||||
val symbolProvider = testSymbolProvider(model)
|
||||
val parserGenerator = Ec2QuerySerializerGenerator(testProtocolConfig(model))
|
||||
val operationGenerator = parserGenerator.operationSerializer(model.lookup("test#Op"))
|
||||
|
|
|
@ -98,7 +98,7 @@ class JsonSerializerGeneratorTest {
|
|||
|
||||
@Test
|
||||
fun `generates valid serializers`() {
|
||||
val model = RecursiveShapeBoxer.transform(OperationNormalizer(baseModel).transformModel())
|
||||
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
|
||||
val symbolProvider = testSymbolProvider(model)
|
||||
val parserSerializer = JsonSerializerGenerator(
|
||||
testProtocolConfig(model),
|
||||
|
|
|
@ -103,7 +103,7 @@ internal class XmlBindingTraitSerializerGeneratorTest {
|
|||
|
||||
@Test
|
||||
fun `generates valid serializers`() {
|
||||
val model = RecursiveShapeBoxer.transform(OperationNormalizer(baseModel).transformModel())
|
||||
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
|
||||
val symbolProvider = testSymbolProvider(model)
|
||||
val parserGenerator = XmlBindingTraitSerializerGenerator(
|
||||
testProtocolConfig(model),
|
||||
|
@ -124,7 +124,7 @@ internal class XmlBindingTraitSerializerGeneratorTest {
|
|||
.build()
|
||||
).build().unwrap();
|
||||
let serialized = ${writer.format(operationParser)}(&inp.payload.unwrap()).unwrap();
|
||||
let output = std::str::from_utf8(serialized.bytes().unwrap()).unwrap();
|
||||
let output = std::str::from_utf8(&serialized).unwrap();
|
||||
assert_eq!(output, "<Top extra=\"45\"><field>hello!</field><recursive extra=\"55\"></recursive></Top>");
|
||||
"""
|
||||
)
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0.
|
||||
*/
|
||||
|
||||
package software.amazon.smithy.rust.codegen.smithy.transformers
|
||||
|
||||
import io.kotest.matchers.shouldBe
|
||||
import org.junit.jupiter.api.Test
|
||||
import software.amazon.smithy.model.shapes.ShapeId
|
||||
import software.amazon.smithy.model.shapes.UnionShape
|
||||
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticEventStreamUnionTrait
|
||||
import software.amazon.smithy.rust.codegen.testutil.asSmithyModel
|
||||
import software.amazon.smithy.rust.codegen.util.expectTrait
|
||||
import software.amazon.smithy.rust.codegen.util.hasTrait
|
||||
|
||||
class EventStreamNormalizerTest {
|
||||
@Test
|
||||
fun `it should leave normal unions alone`() {
|
||||
val transformed = EventStreamNormalizer.transform(
|
||||
"""
|
||||
namespace test
|
||||
union SomeNormalUnion {
|
||||
Foo: String,
|
||||
Bar: Long,
|
||||
}
|
||||
""".asSmithyModel()
|
||||
)
|
||||
|
||||
val shape = transformed.expectShape(ShapeId.from("test#SomeNormalUnion"), UnionShape::class.java)
|
||||
shape.hasTrait<SyntheticEventStreamUnionTrait>() shouldBe false
|
||||
shape.memberNames shouldBe listOf("Foo", "Bar")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `it should transform event stream unions`() {
|
||||
val transformed = EventStreamNormalizer.transform(
|
||||
"""
|
||||
namespace test
|
||||
|
||||
structure SomeMember {
|
||||
}
|
||||
|
||||
@error("client")
|
||||
structure SomeError {
|
||||
}
|
||||
|
||||
@streaming
|
||||
union SomeEventStream {
|
||||
SomeMember: SomeMember,
|
||||
SomeError: SomeError,
|
||||
}
|
||||
""".asSmithyModel()
|
||||
)
|
||||
|
||||
val shape = transformed.expectShape(ShapeId.from("test#SomeEventStream"), UnionShape::class.java)
|
||||
shape.hasTrait<SyntheticEventStreamUnionTrait>() shouldBe true
|
||||
shape.memberNames shouldBe listOf("SomeMember")
|
||||
|
||||
val trait = shape.expectTrait<SyntheticEventStreamUnionTrait>()
|
||||
trait.errorMembers.map { it.memberName } shouldBe listOf("SomeError")
|
||||
}
|
||||
}
|
|
@ -26,8 +26,7 @@ internal class OperationNormalizerTest {
|
|||
""".asSmithyModel()
|
||||
val operationId = ShapeId.from("smithy.test#Empty")
|
||||
model.expectShape(operationId, OperationShape::class.java).input.isPresent shouldBe false
|
||||
val sut = OperationNormalizer(model)
|
||||
val modified = sut.transformModel()
|
||||
val modified = OperationNormalizer.transform(model)
|
||||
val operation = modified.expectShape(operationId, OperationShape::class.java)
|
||||
operation.input.isPresent shouldBe true
|
||||
operation.input.get().name shouldBe "EmptyInput"
|
||||
|
@ -55,8 +54,7 @@ internal class OperationNormalizerTest {
|
|||
""".asSmithyModel()
|
||||
val operationId = ShapeId.from("smithy.test#MyOp")
|
||||
model.expectShape(operationId, OperationShape::class.java).input.isPresent shouldBe true
|
||||
val sut = OperationNormalizer(model)
|
||||
val modified = sut.transformModel()
|
||||
val modified = OperationNormalizer.transform(model)
|
||||
val operation = modified.expectShape(operationId, OperationShape::class.java)
|
||||
operation.input.isPresent shouldBe true
|
||||
val inputId = operation.input.get()
|
||||
|
@ -79,8 +77,7 @@ internal class OperationNormalizerTest {
|
|||
""".asSmithyModel()
|
||||
val operationId = ShapeId.from("smithy.test#MyOp")
|
||||
model.expectShape(operationId, OperationShape::class.java).output.isPresent shouldBe true
|
||||
val sut = OperationNormalizer(model)
|
||||
val modified = sut.transformModel()
|
||||
val modified = OperationNormalizer.transform(model)
|
||||
val operation = modified.expectShape(operationId, OperationShape::class.java)
|
||||
operation.output.isPresent shouldBe true
|
||||
val outputId = operation.output.get()
|
||||
|
|
|
@ -11,8 +11,8 @@ are to allow this crate to be compilable and testable in isolation, no client co
|
|||
[dependencies]
|
||||
"bytes" = "1"
|
||||
"http" = "0.2.1"
|
||||
"smithy-types" = { version = "0.0.1", path = "../smithy-types" }
|
||||
"smithy-http" = { version = "0.0.1", path = "../smithy-http" }
|
||||
"smithy-types" = { path = "../smithy-types" }
|
||||
"smithy-http" = { path = "../smithy-http" }
|
||||
"smithy-json" = { path = "../smithy-json" }
|
||||
"smithy-query" = { path = "../smithy-query" }
|
||||
"smithy-xml" = { path = "../smithy-xml" }
|
||||
|
|
|
@ -23,6 +23,8 @@ pub enum Error {
|
|||
PayloadTooLong,
|
||||
PreludeChecksumMismatch(u32, u32),
|
||||
TimestampValueTooLarge(Instant),
|
||||
Marshalling(String),
|
||||
Unmarshalling(String),
|
||||
}
|
||||
|
||||
impl StdError for Error {}
|
||||
|
@ -56,6 +58,8 @@ impl fmt::Display for Error {
|
|||
"timestamp value {:?} is too large to fit into an i64",
|
||||
time
|
||||
),
|
||||
Marshalling(error) => write!(f, "failed to marshall message: {}", error),
|
||||
Unmarshalling(error) => write!(f, "failed to unmarshall message: {}", error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ use crate::str_bytes::StrBytes;
|
|||
use bytes::{Buf, BufMut, Bytes};
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
use std::error::Error as StdError;
|
||||
use std::fmt;
|
||||
use std::mem::size_of;
|
||||
|
||||
const PRELUDE_LENGTH_BYTES: u32 = 3 * size_of::<u32>() as u32;
|
||||
|
@ -23,24 +24,36 @@ const MIN_HEADER_LEN: usize = 2;
|
|||
pub type SignMessageError = Box<dyn StdError + Send + Sync + 'static>;
|
||||
|
||||
/// Signs an Event Stream message.
|
||||
pub trait SignMessage {
|
||||
pub trait SignMessage: fmt::Debug {
|
||||
fn sign(&mut self, message: Message) -> Result<Message, SignMessageError>;
|
||||
}
|
||||
|
||||
/// Converts a Smithy modeled Event Stream type into a [`Message`](Message).
|
||||
pub trait MarshallMessage {
|
||||
pub trait MarshallMessage: fmt::Debug {
|
||||
/// Smithy modeled input type to convert from.
|
||||
type Input;
|
||||
|
||||
fn marshall(&self, input: Self::Input) -> Result<Message, Error>;
|
||||
}
|
||||
|
||||
/// A successfully unmarshalled message that is either an `Event` or an `Error`.
|
||||
#[derive(Debug)]
|
||||
pub enum UnmarshalledMessage<T, E> {
|
||||
Event(T),
|
||||
Error(E),
|
||||
}
|
||||
|
||||
/// Converts an Event Stream [`Message`](Message) into a Smithy modeled type.
|
||||
pub trait UnmarshallMessage {
|
||||
pub trait UnmarshallMessage: fmt::Debug {
|
||||
/// Smithy modeled type to convert into.
|
||||
type Output;
|
||||
/// Smithy modeled error to convert into.
|
||||
type Error;
|
||||
|
||||
fn unmarshall(&self, message: Message) -> Result<Self::Output, Error>;
|
||||
fn unmarshall(
|
||||
&self,
|
||||
message: &Message,
|
||||
) -> Result<UnmarshalledMessage<Self::Output, Self::Error>, Error>;
|
||||
}
|
||||
|
||||
mod value {
|
||||
|
@ -68,7 +81,7 @@ mod value {
|
|||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub enum HeaderValue {
|
||||
Bool(bool),
|
||||
Byte(u8),
|
||||
Byte(i8),
|
||||
Int16(i16),
|
||||
Int32(i32),
|
||||
Int64(i64),
|
||||
|
@ -78,6 +91,71 @@ mod value {
|
|||
Uuid(u128),
|
||||
}
|
||||
|
||||
impl HeaderValue {
|
||||
pub fn as_bool(&self) -> Result<bool, &Self> {
|
||||
match self {
|
||||
HeaderValue::Bool(value) => Ok(*value),
|
||||
_ => Err(self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_byte(&self) -> Result<i8, &Self> {
|
||||
match self {
|
||||
HeaderValue::Byte(value) => Ok(*value),
|
||||
_ => Err(self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_int16(&self) -> Result<i16, &Self> {
|
||||
match self {
|
||||
HeaderValue::Int16(value) => Ok(*value),
|
||||
_ => Err(self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_int32(&self) -> Result<i32, &Self> {
|
||||
match self {
|
||||
HeaderValue::Int32(value) => Ok(*value),
|
||||
_ => Err(self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_int64(&self) -> Result<i64, &Self> {
|
||||
match self {
|
||||
HeaderValue::Int64(value) => Ok(*value),
|
||||
_ => Err(self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_byte_array(&self) -> Result<&Bytes, &Self> {
|
||||
match self {
|
||||
HeaderValue::ByteArray(value) => Ok(value),
|
||||
_ => Err(self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_string(&self) -> Result<&StrBytes, &Self> {
|
||||
match self {
|
||||
HeaderValue::String(value) => Ok(value),
|
||||
_ => Err(self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_timestamp(&self) -> Result<Instant, &Self> {
|
||||
match self {
|
||||
HeaderValue::Timestamp(value) => Ok(*value),
|
||||
_ => Err(self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_uuid(&self) -> Result<u128, &Self> {
|
||||
match self {
|
||||
HeaderValue::Uuid(value) => Ok(*value),
|
||||
_ => Err(self),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! read_value {
|
||||
($buf:ident, $typ:ident, $size_typ:ident, $read_fn:ident) => {
|
||||
if $buf.remaining() >= size_of::<$size_typ>() {
|
||||
|
@ -94,7 +172,7 @@ mod value {
|
|||
match value_type {
|
||||
TYPE_TRUE => Ok(HeaderValue::Bool(true)),
|
||||
TYPE_FALSE => Ok(HeaderValue::Bool(false)),
|
||||
TYPE_BYTE => read_value!(buffer, Byte, u8, get_u8),
|
||||
TYPE_BYTE => read_value!(buffer, Byte, i8, get_i8),
|
||||
TYPE_INT16 => read_value!(buffer, Int16, i16, get_i16),
|
||||
TYPE_INT32 => read_value!(buffer, Int32, i32, get_i32),
|
||||
TYPE_INT64 => read_value!(buffer, Int64, i64, get_i64),
|
||||
|
@ -137,7 +215,7 @@ mod value {
|
|||
Bool(val) => buffer.put_u8(if *val { TYPE_TRUE } else { TYPE_FALSE }),
|
||||
Byte(val) => {
|
||||
buffer.put_u8(TYPE_BYTE);
|
||||
buffer.put_u8(*val);
|
||||
buffer.put_i8(*val);
|
||||
}
|
||||
Int16(val) => {
|
||||
buffer.put_u8(TYPE_INT16);
|
||||
|
@ -184,7 +262,7 @@ mod value {
|
|||
Ok(match value_type {
|
||||
TYPE_TRUE => HeaderValue::Bool(true),
|
||||
TYPE_FALSE => HeaderValue::Bool(false),
|
||||
TYPE_BYTE => HeaderValue::Byte(u8::arbitrary(unstruct)?),
|
||||
TYPE_BYTE => HeaderValue::Byte(i8::arbitrary(unstruct)?),
|
||||
TYPE_INT16 => HeaderValue::Int16(i16::arbitrary(unstruct)?),
|
||||
TYPE_INT32 => HeaderValue::Int32(i32::arbitrary(unstruct)?),
|
||||
TYPE_INT64 => HeaderValue::Int64(i64::arbitrary(unstruct)?),
|
||||
|
@ -289,6 +367,14 @@ impl Message {
|
|||
}
|
||||
}
|
||||
|
||||
/// Creates a message with the given `headers` and `payload`.
|
||||
pub fn new_from_parts(headers: Vec<Header>, payload: impl Into<Bytes>) -> Self {
|
||||
Self {
|
||||
headers,
|
||||
payload: payload.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds a header to the message.
|
||||
pub fn add_header(mut self, header: Header) -> Self {
|
||||
self.headers.push(header);
|
||||
|
@ -609,7 +695,7 @@ pub enum DecodedFrame {
|
|||
|
||||
/// Streaming decoder for decoding a [`Message`] from a stream.
|
||||
#[non_exhaustive]
|
||||
#[derive(Default)]
|
||||
#[derive(Default, Debug)]
|
||||
pub struct MessageFrameDecoder {
|
||||
prelude: [u8; PRELUDE_LENGTH_BYTES_USIZE],
|
||||
prelude_read: bool,
|
||||
|
|
|
@ -8,4 +8,5 @@
|
|||
mod buf;
|
||||
pub mod error;
|
||||
pub mod frame;
|
||||
pub mod smithy;
|
||||
pub mod str_bytes;
|
||||
|
|
|
@ -0,0 +1,179 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0.
|
||||
*/
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::frame::{Header, HeaderValue, Message};
|
||||
use crate::str_bytes::StrBytes;
|
||||
use smithy_types::{Blob, Instant};
|
||||
|
||||
macro_rules! expect_shape_fn {
|
||||
(fn $fn_name:ident[$val_typ:ident] -> $result_typ:ident { $val_name:ident -> $val_expr:expr }) => {
|
||||
pub fn $fn_name(header: &Header) -> Result<$result_typ, Error> {
|
||||
match header.value() {
|
||||
HeaderValue::$val_typ($val_name) => Ok($val_expr),
|
||||
_ => Err(Error::Unmarshalling(format!(
|
||||
"expected '{}' header value to be {}",
|
||||
header.name().as_str(),
|
||||
stringify!($val_typ)
|
||||
))),
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
expect_shape_fn!(fn expect_bool[Bool] -> bool { value -> *value });
|
||||
expect_shape_fn!(fn expect_byte[Byte] -> i8 { value -> *value });
|
||||
expect_shape_fn!(fn expect_int16[Int16] -> i16 { value -> *value });
|
||||
expect_shape_fn!(fn expect_int32[Int32] -> i32 { value -> *value });
|
||||
expect_shape_fn!(fn expect_int64[Int64] -> i64 { value -> *value });
|
||||
expect_shape_fn!(fn expect_byte_array[ByteArray] -> Blob { bytes -> Blob::new(bytes.as_ref()) });
|
||||
expect_shape_fn!(fn expect_string[String] -> String { value -> value.as_str().into() });
|
||||
expect_shape_fn!(fn expect_timestamp[Timestamp] -> Instant { value -> *value });
|
||||
|
||||
pub struct ResponseHeaders<'a> {
|
||||
pub content_type: &'a StrBytes,
|
||||
pub message_type: &'a StrBytes,
|
||||
pub smithy_type: &'a StrBytes,
|
||||
}
|
||||
|
||||
fn expect_header_str_value<'a>(
|
||||
header: Option<&'a Header>,
|
||||
name: &str,
|
||||
) -> Result<&'a StrBytes, Error> {
|
||||
match header {
|
||||
Some(header) => Ok(header.value().as_string().map_err(|value| {
|
||||
Error::Unmarshalling(format!(
|
||||
"expected response {} header to be string, received {:?}",
|
||||
name, value
|
||||
))
|
||||
})?),
|
||||
None => Err(Error::Unmarshalling(format!(
|
||||
"expected response to include {} header, but it was missing",
|
||||
name
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse_response_headers(message: &Message) -> Result<ResponseHeaders, Error> {
|
||||
let (mut content_type, mut message_type, mut event_type, mut exception_type) =
|
||||
(None, None, None, None);
|
||||
for header in message.headers() {
|
||||
match header.name().as_str() {
|
||||
":content-type" => content_type = Some(header),
|
||||
":message-type" => message_type = Some(header),
|
||||
":event-type" => event_type = Some(header),
|
||||
":exception-type" => exception_type = Some(header),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
let message_type = expect_header_str_value(message_type, ":message-type")?;
|
||||
Ok(ResponseHeaders {
|
||||
content_type: expect_header_str_value(content_type, ":content-type")?,
|
||||
message_type,
|
||||
smithy_type: if message_type.as_str() == "event" {
|
||||
expect_header_str_value(event_type, ":event-type")?
|
||||
} else if message_type.as_str() == "exception" {
|
||||
expect_header_str_value(exception_type, ":exception-type")?
|
||||
} else {
|
||||
return Err(Error::Unmarshalling(format!(
|
||||
"unrecognized `:message-type`: {}",
|
||||
message_type.as_str()
|
||||
)));
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::parse_response_headers;
|
||||
use crate::frame::{Header, HeaderValue, Message};
|
||||
|
||||
#[test]
|
||||
fn normal_message() {
|
||||
let message = Message::new(&b"test"[..])
|
||||
.add_header(Header::new(
|
||||
":event-type",
|
||||
HeaderValue::String("Foo".into()),
|
||||
))
|
||||
.add_header(Header::new(
|
||||
":content-type",
|
||||
HeaderValue::String("application/json".into()),
|
||||
))
|
||||
.add_header(Header::new(
|
||||
":message-type",
|
||||
HeaderValue::String("event".into()),
|
||||
));
|
||||
let parsed = parse_response_headers(&message).unwrap();
|
||||
assert_eq!("Foo", parsed.smithy_type.as_str());
|
||||
assert_eq!("application/json", parsed.content_type.as_str());
|
||||
assert_eq!("event", parsed.message_type.as_str());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn error_message() {
|
||||
let message = Message::new(&b"test"[..])
|
||||
.add_header(Header::new(
|
||||
":exception-type",
|
||||
HeaderValue::String("BadRequestException".into()),
|
||||
))
|
||||
.add_header(Header::new(
|
||||
":content-type",
|
||||
HeaderValue::String("application/json".into()),
|
||||
))
|
||||
.add_header(Header::new(
|
||||
":message-type",
|
||||
HeaderValue::String("exception".into()),
|
||||
));
|
||||
let parsed = parse_response_headers(&message).unwrap();
|
||||
assert_eq!("BadRequestException", parsed.smithy_type.as_str());
|
||||
assert_eq!("application/json", parsed.content_type.as_str());
|
||||
assert_eq!("exception", parsed.message_type.as_str());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn missing_exception_type() {
|
||||
let message = Message::new(&b"test"[..])
|
||||
.add_header(Header::new(
|
||||
":content-type",
|
||||
HeaderValue::String("application/json".into()),
|
||||
))
|
||||
.add_header(Header::new(
|
||||
":message-type",
|
||||
HeaderValue::String("exception".into()),
|
||||
));
|
||||
let error = parse_response_headers(&message).err().unwrap().to_string();
|
||||
assert_eq!("failed to unmarshall message: expected response to include :exception-type header, but it was missing", error);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn missing_event_type() {
|
||||
let message = Message::new(&b"test"[..])
|
||||
.add_header(Header::new(
|
||||
":content-type",
|
||||
HeaderValue::String("application/json".into()),
|
||||
))
|
||||
.add_header(Header::new(
|
||||
":message-type",
|
||||
HeaderValue::String("event".into()),
|
||||
));
|
||||
let error = parse_response_headers(&message).err().unwrap().to_string();
|
||||
assert_eq!("failed to unmarshall message: expected response to include :event-type header, but it was missing", error);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn missing_content_type() {
|
||||
let message = Message::new(&b"test"[..])
|
||||
.add_header(Header::new(
|
||||
":event-type",
|
||||
HeaderValue::String("Foo".into()),
|
||||
))
|
||||
.add_header(Header::new(
|
||||
":message-type",
|
||||
HeaderValue::String("event".into()),
|
||||
));
|
||||
let error = parse_response_headers(&message).err().unwrap().to_string();
|
||||
assert_eq!("failed to unmarshall message: expected response to include :content-type header, but it was missing", error);
|
||||
}
|
||||
}
|
|
@ -8,7 +8,7 @@ license = "Apache-2.0"
|
|||
[features]
|
||||
bytestream-util = ["tokio/fs", "tokio-util/io"]
|
||||
event-stream = ["smithy-eventstream"]
|
||||
default = ["bytestream-util", "event-stream"]
|
||||
default = ["bytestream-util"]
|
||||
|
||||
[dependencies]
|
||||
smithy-types = { path = "../smithy-types" }
|
||||
|
|
|
@ -13,13 +13,50 @@ use futures_core::Stream;
|
|||
use hyper::body::HttpBody;
|
||||
use pin_project::pin_project;
|
||||
use smithy_eventstream::frame::{
|
||||
DecodedFrame, MarshallMessage, MessageFrameDecoder, SignMessage, UnmarshallMessage,
|
||||
DecodedFrame, MarshallMessage, Message, MessageFrameDecoder, SignMessage, UnmarshallMessage,
|
||||
UnmarshalledMessage,
|
||||
};
|
||||
use std::error::Error as StdError;
|
||||
use std::fmt;
|
||||
use std::marker::PhantomData;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
pub type BoxError = Box<dyn StdError + Send + Sync + 'static>;
|
||||
|
||||
/// Input type for Event Streams.
|
||||
pub struct EventStreamInput<T> {
|
||||
input_stream: Pin<Box<dyn Stream<Item = Result<T, BoxError>> + Send>>,
|
||||
}
|
||||
|
||||
impl<T> fmt::Debug for EventStreamInput<T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "EventStreamInput(Box<dyn Stream>)")
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> EventStreamInput<T> {
|
||||
#[doc(hidden)]
|
||||
pub fn into_body_stream<E: StdError + Send + Sync + 'static>(
|
||||
self,
|
||||
marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
|
||||
signer: impl SignMessage + Send + Sync + 'static,
|
||||
) -> MessageStreamAdapter<T, E> {
|
||||
MessageStreamAdapter::new(marshaller, signer, self.input_stream)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S> From<S> for EventStreamInput<T>
|
||||
where
|
||||
S: Stream<Item = Result<T, BoxError>> + Send + 'static,
|
||||
{
|
||||
fn from(stream: S) -> Self {
|
||||
EventStreamInput {
|
||||
input_stream: Box::pin(stream),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Adapts a `Stream<SmithyMessageType>` to a signed `Stream<Bytes>` by using the provided
|
||||
/// message marshaller and signer implementations.
|
||||
///
|
||||
|
@ -30,24 +67,32 @@ pub struct MessageStreamAdapter<T, E> {
|
|||
marshaller: Box<dyn MarshallMessage<Input = T> + Send + Sync>,
|
||||
signer: Box<dyn SignMessage + Send + Sync>,
|
||||
#[pin]
|
||||
stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>>,
|
||||
stream: Pin<Box<dyn Stream<Item = Result<T, BoxError>> + Send>>,
|
||||
_phantom: PhantomData<E>,
|
||||
}
|
||||
|
||||
impl<T, E: StdError + Send + Sync + 'static> MessageStreamAdapter<T, E> {
|
||||
impl<T, E> MessageStreamAdapter<T, E>
|
||||
where
|
||||
E: StdError + Send + Sync + 'static,
|
||||
{
|
||||
pub fn new(
|
||||
marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
|
||||
signer: impl SignMessage + Send + Sync + 'static,
|
||||
stream: impl Stream<Item = Result<T, E>> + Send + Sync + 'static,
|
||||
stream: Pin<Box<dyn Stream<Item = Result<T, BoxError>> + Send>>,
|
||||
) -> Self {
|
||||
MessageStreamAdapter {
|
||||
marshaller: Box::new(marshaller),
|
||||
signer: Box::new(signer),
|
||||
stream: Box::pin(stream),
|
||||
stream,
|
||||
_phantom: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, E: StdError + Send + Sync + 'static> Stream for MessageStreamAdapter<T, E> {
|
||||
impl<T, E> Stream for MessageStreamAdapter<T, E>
|
||||
where
|
||||
E: StdError + Send + Sync + 'static,
|
||||
{
|
||||
type Item = Result<Bytes, SdkError<E>>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
|
@ -56,7 +101,7 @@ impl<T, E: StdError + Send + Sync + 'static> Stream for MessageStreamAdapter<T,
|
|||
Poll::Ready(message_option) => {
|
||||
if let Some(message_result) = message_option {
|
||||
let message_result =
|
||||
message_result.map_err(|err| SdkError::ConstructionFailure(Box::new(err)));
|
||||
message_result.map_err(|err| SdkError::ConstructionFailure(err));
|
||||
let message = this
|
||||
.marshaller
|
||||
.marshall(message_result?)
|
||||
|
@ -80,17 +125,21 @@ impl<T, E: StdError + Send + Sync + 'static> Stream for MessageStreamAdapter<T,
|
|||
}
|
||||
|
||||
/// Receives Smithy-modeled messages out of an Event Stream.
|
||||
pub struct Receiver<T, E: StdError + Send + Sync> {
|
||||
unmarshaller: Box<dyn UnmarshallMessage<Output = T>>,
|
||||
#[derive(Debug)]
|
||||
pub struct Receiver<T, E> {
|
||||
unmarshaller: Box<dyn UnmarshallMessage<Output = T, Error = E>>,
|
||||
decoder: MessageFrameDecoder,
|
||||
buffer: SegmentedBuf<Bytes>,
|
||||
body: SdkBody,
|
||||
_phantom: PhantomData<E>,
|
||||
}
|
||||
|
||||
impl<T, E: StdError + Send + Sync> Receiver<T, E> {
|
||||
impl<T, E> Receiver<T, E> {
|
||||
/// Creates a new `Receiver` with the given message unmarshaller and SDK body.
|
||||
pub fn new(unmarshaller: impl UnmarshallMessage<Output = T> + 'static, body: SdkBody) -> Self {
|
||||
pub fn new(
|
||||
unmarshaller: impl UnmarshallMessage<Output = T, Error = E> + 'static,
|
||||
body: SdkBody,
|
||||
) -> Self {
|
||||
Receiver {
|
||||
unmarshaller: Box::new(unmarshaller),
|
||||
decoder: MessageFrameDecoder::new(),
|
||||
|
@ -104,7 +153,7 @@ impl<T, E: StdError + Send + Sync> Receiver<T, E> {
|
|||
/// it returns an `Ok(None)`. If there is a transport layer error, it will return
|
||||
/// `Err(SdkError::DispatchFailure)`. Service-modeled errors will be a part of the returned
|
||||
/// messages.
|
||||
pub async fn recv(&mut self) -> Result<Option<T>, SdkError<E>> {
|
||||
pub async fn recv(&mut self) -> Result<Option<T>, SdkError<E, Message>> {
|
||||
let next_chunk = self
|
||||
.body
|
||||
.data()
|
||||
|
@ -119,11 +168,16 @@ impl<T, E: StdError + Send + Sync> Receiver<T, E> {
|
|||
.decode_frame(&mut self.buffer)
|
||||
.map_err(|err| SdkError::DispatchFailure(Box::new(err)))?
|
||||
{
|
||||
return Ok(Some(
|
||||
self.unmarshaller
|
||||
.unmarshall(message)
|
||||
.map_err(|err| SdkError::DispatchFailure(Box::new(err)))?,
|
||||
));
|
||||
return match self
|
||||
.unmarshaller
|
||||
.unmarshall(&message)
|
||||
.map_err(|err| SdkError::DispatchFailure(Box::new(err)))?
|
||||
{
|
||||
UnmarshalledMessage::Event(event) => Ok(Some(event)),
|
||||
UnmarshalledMessage::Error(err) => {
|
||||
Err(SdkError::ServiceError { err, raw: message })
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
Ok(None)
|
||||
|
@ -134,7 +188,7 @@ impl<T, E: StdError + Send + Sync> Receiver<T, E> {
|
|||
mod tests {
|
||||
use super::{MarshallMessage, Receiver, UnmarshallMessage};
|
||||
use crate::body::SdkBody;
|
||||
use crate::event_stream::MessageStreamAdapter;
|
||||
use crate::event_stream::{EventStreamInput, MessageStreamAdapter};
|
||||
use crate::result::SdkError;
|
||||
use async_stream::stream;
|
||||
use bytes::Bytes;
|
||||
|
@ -142,7 +196,9 @@ mod tests {
|
|||
use futures_util::stream::StreamExt;
|
||||
use hyper::body::Body;
|
||||
use smithy_eventstream::error::Error as EventStreamError;
|
||||
use smithy_eventstream::frame::{Header, HeaderValue, Message, SignMessage, SignMessageError};
|
||||
use smithy_eventstream::frame::{
|
||||
Header, HeaderValue, Message, SignMessage, SignMessageError, UnmarshalledMessage,
|
||||
};
|
||||
use std::error::Error as StdError;
|
||||
use std::io::{Error as IOError, ErrorKind};
|
||||
|
||||
|
@ -164,25 +220,31 @@ mod tests {
|
|||
impl StdError for FakeError {}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq)]
|
||||
struct UnmarshalledMessage(String);
|
||||
struct TestMessage(String);
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Marshaller;
|
||||
impl MarshallMessage for Marshaller {
|
||||
type Input = UnmarshalledMessage;
|
||||
type Input = TestMessage;
|
||||
|
||||
fn marshall(&self, input: Self::Input) -> Result<Message, EventStreamError> {
|
||||
Ok(Message::new(input.0.as_bytes().to_vec()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Unmarshaller;
|
||||
impl UnmarshallMessage for Unmarshaller {
|
||||
type Output = UnmarshalledMessage;
|
||||
type Output = TestMessage;
|
||||
type Error = EventStreamError;
|
||||
|
||||
fn unmarshall(&self, message: Message) -> Result<Self::Output, EventStreamError> {
|
||||
Ok(UnmarshalledMessage(
|
||||
fn unmarshall(
|
||||
&self,
|
||||
message: &Message,
|
||||
) -> Result<UnmarshalledMessage<Self::Output, Self::Error>, EventStreamError> {
|
||||
Ok(UnmarshalledMessage::Event(TestMessage(
|
||||
std::str::from_utf8(&message.payload()[..]).unwrap().into(),
|
||||
))
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -192,14 +254,13 @@ mod tests {
|
|||
vec![Ok(encode_message("one")), Ok(encode_message("two"))];
|
||||
let chunk_stream = futures_util::stream::iter(chunks);
|
||||
let body = SdkBody::from(Body::wrap_stream(chunk_stream));
|
||||
let mut receiver =
|
||||
Receiver::<UnmarshalledMessage, EventStreamError>::new(Unmarshaller, body);
|
||||
let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
|
||||
assert_eq!(
|
||||
UnmarshalledMessage("one".into()),
|
||||
TestMessage("one".into()),
|
||||
receiver.recv().await.unwrap().unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
UnmarshalledMessage("two".into()),
|
||||
TestMessage("two".into()),
|
||||
receiver.recv().await.unwrap().unwrap()
|
||||
);
|
||||
}
|
||||
|
@ -212,10 +273,9 @@ mod tests {
|
|||
];
|
||||
let chunk_stream = futures_util::stream::iter(chunks);
|
||||
let body = SdkBody::from(Body::wrap_stream(chunk_stream));
|
||||
let mut receiver =
|
||||
Receiver::<UnmarshalledMessage, EventStreamError>::new(Unmarshaller, body);
|
||||
let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
|
||||
assert_eq!(
|
||||
UnmarshalledMessage("one".into()),
|
||||
TestMessage("one".into()),
|
||||
receiver.recv().await.unwrap().unwrap()
|
||||
);
|
||||
assert!(matches!(
|
||||
|
@ -234,10 +294,9 @@ mod tests {
|
|||
];
|
||||
let chunk_stream = futures_util::stream::iter(chunks);
|
||||
let body = SdkBody::from(Body::wrap_stream(chunk_stream));
|
||||
let mut receiver =
|
||||
Receiver::<UnmarshalledMessage, EventStreamError>::new(Unmarshaller, body);
|
||||
let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
|
||||
assert_eq!(
|
||||
UnmarshalledMessage("one".into()),
|
||||
TestMessage("one".into()),
|
||||
receiver.recv().await.unwrap().unwrap()
|
||||
);
|
||||
assert!(matches!(
|
||||
|
@ -246,6 +305,16 @@ mod tests {
|
|||
));
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TestServiceError;
|
||||
impl std::fmt::Display for TestServiceError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "TestServiceError")
|
||||
}
|
||||
}
|
||||
impl StdError for TestServiceError {}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TestSigner;
|
||||
impl SignMessage for TestSigner {
|
||||
fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
|
||||
|
@ -267,12 +336,15 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn message_stream_adapter_success() {
|
||||
let stream = stream! {
|
||||
yield Ok(UnmarshalledMessage("test".into()));
|
||||
yield Ok(TestMessage("test".into()));
|
||||
};
|
||||
let mut adapter =
|
||||
check_compatible_with_hyper_wrap_stream(
|
||||
MessageStreamAdapter::<_, EventStreamError>::new(Marshaller, TestSigner, stream),
|
||||
);
|
||||
check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::<
|
||||
TestMessage,
|
||||
TestServiceError,
|
||||
>::new(
|
||||
Marshaller, TestSigner, Box::pin(stream)
|
||||
));
|
||||
|
||||
let mut sent_bytes = adapter.next().await.unwrap().unwrap();
|
||||
let sent = Message::read_from(&mut sent_bytes).unwrap();
|
||||
|
@ -285,12 +357,15 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn message_stream_adapter_construction_failure() {
|
||||
let stream = stream! {
|
||||
yield Err(EventStreamError::InvalidMessageLength);
|
||||
yield Err(EventStreamError::InvalidMessageLength.into());
|
||||
};
|
||||
let mut adapter =
|
||||
check_compatible_with_hyper_wrap_stream(
|
||||
MessageStreamAdapter::<UnmarshalledMessage, _>::new(Marshaller, TestSigner, stream),
|
||||
);
|
||||
check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::<
|
||||
TestMessage,
|
||||
TestServiceError,
|
||||
>::new(
|
||||
Marshaller, TestSigner, Box::pin(stream)
|
||||
));
|
||||
|
||||
let result = adapter.next().await.unwrap();
|
||||
assert!(result.is_err());
|
||||
|
@ -299,4 +374,18 @@ mod tests {
|
|||
SdkError::ConstructionFailure(_)
|
||||
));
|
||||
}
|
||||
|
||||
// Verify the developer experience for this compiles
|
||||
#[allow(unused)]
|
||||
fn event_stream_input_ergonomics() {
|
||||
fn check(input: impl Into<EventStreamInput<TestMessage>>) {
|
||||
let _: EventStreamInput<TestMessage> = input.into();
|
||||
}
|
||||
check(stream! {
|
||||
yield Ok(TestMessage("test".into()));
|
||||
});
|
||||
check(stream! {
|
||||
yield Err(EventStreamError::InvalidMessageLength.into());
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -159,6 +159,11 @@ impl Request {
|
|||
}
|
||||
}
|
||||
|
||||
/// Creates a new operation `Request` from its parts.
|
||||
pub fn from_parts(inner: http::Request<SdkBody>, properties: Arc<Mutex<PropertyBag>>) -> Self {
|
||||
Request { inner, properties }
|
||||
}
|
||||
|
||||
/// Allows modification of the HTTP request and associated properties with a fallible closure.
|
||||
pub fn augment<T>(
|
||||
self,
|
||||
|
|
|
@ -19,7 +19,7 @@ pub struct SdkSuccess<O> {
|
|||
|
||||
/// Failed SDK Result
|
||||
#[derive(Debug)]
|
||||
pub enum SdkError<E> {
|
||||
pub enum SdkError<E, R = operation::Response> {
|
||||
/// The request failed during construction. It was not dispatched over the network.
|
||||
ConstructionFailure(BoxError),
|
||||
|
||||
|
@ -35,10 +35,10 @@ pub enum SdkError<E> {
|
|||
},
|
||||
|
||||
/// An error response was received from the service
|
||||
ServiceError { err: E, raw: operation::Response },
|
||||
ServiceError { err: E, raw: R },
|
||||
}
|
||||
|
||||
impl<E> Display for SdkError<E>
|
||||
impl<E, R> Display for SdkError<E, R>
|
||||
where
|
||||
E: Error,
|
||||
{
|
||||
|
|
Loading…
Reference in New Issue