Add incomplete Event Stream support with working Amazon Transcribe example (#653)

* Add incomplete Event Stream support with working Amazon Transcribe example

* Make the raw response in SdkError generic

* Fix XmlBindingTraitSerializerGeneratorTest

* Make the build aware of the SMITHYRS_EXPERIMENTAL_EVENTSTREAM switch

* Fix SigV4SigningCustomizationTest

* Update changelog

* Fix build when SMITHYRS_EXPERIMENTAL_EVENTSTREAM is not set

* Add initial unit test for EventStreamUnmarshallerGenerator

* Add event header unmarshalling support

* Don't pull in event stream dependencies by default

* Only add event stream signer to config for services that need it

* Move event stream inlineables into smithy-eventstream

* Fix some clippy lints

* Transform event stream unions

* Fix crash in SigV4SigningDecorator

* Add test for unmarshalling errors

* Incorporate CR feedback
This commit is contained in:
John DiSanti 2021-08-20 10:50:42 -07:00 committed by GitHub
parent b119782a65
commit 3b8f69c18d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
73 changed files with 2389 additions and 245 deletions

View File

@ -3,6 +3,7 @@ vNext (Month Day, Year)
**New this week**
- (When complete) Add Event Stream support (#653, #xyz)
- (When complete) Add profile file provider for region (#594, #xyz)
v0.21 (August 19th, 2021)

View File

@ -7,12 +7,17 @@ license = "Apache-2.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
sign-eventstream = ["smithy-eventstream", "aws-sigv4/sign-eventstream"]
default = []
[dependencies]
http = "0.2.2"
aws-sigv4 = { path = "../aws-sigv4" }
aws-auth = { path = "../aws-auth" }
aws-types = { path = "../aws-types" }
smithy-http = { path = "../../../rust-runtime/smithy-http" }
smithy-eventstream = { path = "../../../rust-runtime/smithy-eventstream", optional = true }
# Trying this out as an experiment. thiserror can be removed and replaced with hand written error
# implementations and it is not a breaking change.
thiserror = "1"

View File

@ -0,0 +1,129 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
use crate::middleware::Signature;
use aws_auth::Credentials;
use aws_sigv4::event_stream::sign_message;
use aws_sigv4::SigningParams;
use aws_types::region::SigningRegion;
use aws_types::SigningService;
use smithy_eventstream::frame::{Message, SignMessage, SignMessageError};
use smithy_http::property_bag::PropertyBag;
use std::sync::{Arc, Mutex, MutexGuard};
use std::time::SystemTime;
/// Event Stream SigV4 signing implementation.
#[derive(Debug)]
pub struct SigV4Signer {
properties: Arc<Mutex<PropertyBag>>,
last_signature: Option<String>,
}
impl SigV4Signer {
pub fn new(properties: Arc<Mutex<PropertyBag>>) -> Self {
Self {
properties,
last_signature: None,
}
}
}
impl SignMessage for SigV4Signer {
fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
let properties = PropertyAccessor(self.properties.lock().unwrap());
if self.last_signature.is_none() {
// The Signature property should exist in the property bag for all Event Stream requests.
self.last_signature = Some(properties.expect::<Signature>().as_ref().into())
}
// Every single one of these values would have been retrieved during the initial request,
// so we can safely assume they all exist in the property bag at this point.
let credentials = properties.expect::<Credentials>();
let region = properties.expect::<SigningRegion>();
let signing_service = properties.expect::<SigningService>();
let time = properties
.get::<SystemTime>()
.copied()
.unwrap_or_else(SystemTime::now);
let params = SigningParams {
access_key: credentials.access_key_id(),
secret_key: credentials.secret_access_key(),
security_token: credentials.session_token(),
region: region.as_ref(),
service_name: signing_service.as_ref(),
date_time: time.into(),
settings: (),
};
let (signed_message, signature) =
sign_message(&message, self.last_signature.as_ref().unwrap(), &params).into_parts();
self.last_signature = Some(signature);
Ok(signed_message)
}
}
// TODO(EventStream): Make a new type around `Arc<Mutex<PropertyBag>>` called `SharedPropertyBag`
// and abstract the mutex away entirely.
struct PropertyAccessor<'a>(MutexGuard<'a, PropertyBag>);
impl<'a> PropertyAccessor<'a> {
fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
self.0.get::<T>()
}
fn expect<T: Send + Sync + 'static>(&self) -> &T {
self.get::<T>()
.expect("property should have been inserted into property bag via middleware")
}
}
#[cfg(test)]
mod tests {
use crate::event_stream::SigV4Signer;
use crate::middleware::Signature;
use aws_auth::Credentials;
use aws_types::region::Region;
use aws_types::region::SigningRegion;
use aws_types::SigningService;
use smithy_eventstream::frame::{HeaderValue, Message, SignMessage};
use smithy_http::property_bag::PropertyBag;
use std::sync::{Arc, Mutex};
use std::time::{Duration, UNIX_EPOCH};
#[test]
fn sign_message() {
let region = Region::new("us-east-1");
let mut properties = PropertyBag::new();
properties.insert(region.clone());
properties.insert(UNIX_EPOCH + Duration::new(1611160427, 0));
properties.insert(SigningService::from_static("transcribe"));
properties.insert(Credentials::from_keys("AKIAfoo", "bar", None));
properties.insert(SigningRegion::from(region));
properties.insert(Signature::new("initial-signature".into()));
let mut signer = SigV4Signer::new(Arc::new(Mutex::new(properties)));
let mut signatures = Vec::new();
for _ in 0..5 {
let signed = signer
.sign(Message::new(&b"identical message"[..]))
.unwrap();
if let HeaderValue::ByteArray(signature) = signed
.headers()
.iter()
.find(|h| h.name().as_str() == ":chunk-signature")
.unwrap()
.value()
{
signatures.push(signature.clone());
} else {
panic!("failed to get the :chunk-signature")
}
}
for i in 1..signatures.len() {
assert_ne!(signatures[i - 1], signatures[i]);
}
}
}

View File

@ -7,5 +7,8 @@
//!
//! In the future, additional signature algorithms can be enabled as Cargo Features.
#[cfg(feature = "sign-eventstream")]
pub mod event_stream;
pub mod middleware;
pub mod signer;

View File

@ -24,8 +24,10 @@ impl Signature {
pub fn new(signature: String) -> Self {
Self(signature)
}
}
pub fn as_str(&self) -> &str {
impl AsRef<str> for Signature {
fn as_ref(&self) -> &str {
&self.0
}
}

View File

@ -10,7 +10,7 @@ description = "AWS SigV4 signer"
[features]
sign-http = ["http", "http-body", "percent-encoding", "form_urlencoded"]
sign-eventstream = ["smithy-eventstream"]
default = ["sign-http", "sign-eventstream"]
default = ["sign-http"]
[dependencies]
chrono = { version = "0.4", default-features = false, features = ["clock", "std"] }

View File

@ -32,5 +32,5 @@ object AwsRuntimeType {
val S3Errors by lazy { RuntimeType.forInlineDependency(InlineAwsDependency.forRustFile("s3_errors")) }
}
fun RuntimeConfig.awsRuntimeDependency(name: String, features: List<String> = listOf()): CargoDependency =
fun RuntimeConfig.awsRuntimeDependency(name: String, features: Set<String> = setOf()): CargoDependency =
CargoDependency(name, awsRoot().crateLocation(), features = features)

View File

@ -49,7 +49,8 @@ class IntegrationTestDependencies(
override fun section(section: LibRsSection) = when (section) {
LibRsSection.Body -> writable {
if (hasTests) {
val smithyClient = CargoDependency.SmithyClient(runtimeConfig).copy(features = listOf("test-util"), scope = DependencyScope.Dev)
val smithyClient = CargoDependency.SmithyClient(runtimeConfig)
.copy(features = setOf("test-util"), scope = DependencyScope.Dev)
addDependency(smithyClient)
addDependency(SerdeJson)
addDependency(Tokio)
@ -63,5 +64,5 @@ class IntegrationTestDependencies(
}
val Criterion = CargoDependency("criterion", CratesIo("0.3"), scope = DependencyScope.Dev)
val SerdeJson = CargoDependency("serde_json", CratesIo("1"), features = emptyList(), scope = DependencyScope.Dev)
val Tokio = CargoDependency("tokio", CratesIo("1"), features = listOf("macros", "test-util"), scope = DependencyScope.Dev)
val SerdeJson = CargoDependency("serde_json", CratesIo("1"), features = emptySet(), scope = DependencyScope.Dev)
val Tokio = CargoDependency("tokio", CratesIo("1"), features = setOf("macros", "test-util"), scope = DependencyScope.Dev)

View File

@ -13,12 +13,14 @@ import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.traits.OptionalAuthTrait
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.Writable
import software.amazon.smithy.rust.codegen.rustlang.asType
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.rustlang.writable
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.customize.OperationCustomization
import software.amazon.smithy.rust.codegen.smithy.customize.OperationSection
import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator
@ -28,11 +30,14 @@ import software.amazon.smithy.rust.codegen.smithy.generators.config.ServiceConfi
import software.amazon.smithy.rust.codegen.smithy.letIf
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectTrait
import software.amazon.smithy.rust.codegen.util.hasEventStreamOperations
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.isInputEventStream
/**
* The SigV4SigningDecorator:
* - adds a `signing_service()` method to `config` to return the default signing service
* - adds a `new_event_stream_signer()` method to `config` to create an Event Stream SigV4 signer
* - sets the `SigningService` during operation construction
* - sets a default `OperationSigningConfig` A future enhancement will customize this for specific services that need
* different behavior.
@ -47,8 +52,12 @@ class SigV4SigningDecorator : RustCodegenDecorator {
protocolConfig: ProtocolConfig,
baseCustomizations: List<ConfigCustomization>
): List<ConfigCustomization> {
return baseCustomizations.letIf(applies(protocolConfig)) {
it + SigV4SigningConfig(protocolConfig.serviceShape.expectTrait())
return baseCustomizations.letIf(applies(protocolConfig)) { customizations ->
customizations + SigV4SigningConfig(
protocolConfig.runtimeConfig,
protocolConfig.serviceShape.hasEventStreamOperations(protocolConfig.model),
protocolConfig.serviceShape.expectTrait()
)
}
}
@ -58,26 +67,63 @@ class SigV4SigningDecorator : RustCodegenDecorator {
baseCustomizations: List<OperationCustomization>
): List<OperationCustomization> {
return baseCustomizations.letIf(applies(protocolConfig)) {
it + SigV4SigningFeature(operation, protocolConfig.runtimeConfig, protocolConfig.serviceShape, protocolConfig.model)
it + SigV4SigningFeature(
protocolConfig.model,
operation,
protocolConfig.runtimeConfig,
protocolConfig.serviceShape,
)
}
}
}
class SigV4SigningConfig(private val sigV4Trait: SigV4Trait) : ConfigCustomization() {
class SigV4SigningConfig(
runtimeConfig: RuntimeConfig,
private val serviceHasEventStream: Boolean,
private val sigV4Trait: SigV4Trait
) : ConfigCustomization() {
private val codegenScope = arrayOf(
"SigV4Signer" to RuntimeType(
"SigV4Signer",
runtimeConfig.awsRuntimeDependency("aws-sig-auth", setOf("sign-eventstream")),
"aws_sig_auth::event_stream"
),
"PropertyBag" to RuntimeType(
"PropertyBag",
CargoDependency.SmithyHttp(runtimeConfig),
"smithy_http::property_bag"
)
)
override fun section(section: ServiceConfig): Writable {
return when (section) {
is ServiceConfig.ConfigImpl -> writable {
rust(
rustTemplate(
"""
/// The signature version 4 service signing name to use in the credential scope when signing requests.
///
/// The signing service may be overidden by the `Endpoint`, or by specifying a custom [`SigningService`](aws_types::SigningService) during
/// operation construction
/// The signing service may be overridden by the `Endpoint`, or by specifying a custom
/// [`SigningService`](aws_types::SigningService) during operation construction
pub fn signing_service(&self) -> &'static str {
${sigV4Trait.name.dq()}
}
"""
""",
*codegenScope
)
if (serviceHasEventStream) {
rustTemplate(
"""
/// Creates a new Event Stream `SignMessage` implementor.
pub fn new_event_stream_signer(
&self,
properties: std::sync::Arc<std::sync::Mutex<#{PropertyBag}>>
) -> #{SigV4Signer} {
#{SigV4Signer}::new(properties)
}
""",
*codegenScope
)
}
}
else -> emptySection
}
@ -95,10 +141,10 @@ fun disableDoubleEncode(service: ServiceShape) = when {
}
class SigV4SigningFeature(
private val model: Model,
private val operation: OperationShape,
runtimeConfig: RuntimeConfig,
private val service: ServiceShape,
model: Model
) :
OperationCustomization() {
private val codegenScope =
@ -111,9 +157,9 @@ class SigV4SigningFeature(
is OperationSection.MutateRequest -> writable {
rustTemplate(
"""
##[allow(unused_mut)]
let mut signing_config = #{sig_auth}::signer::OperationSigningConfig::default_config();
""",
##[allow(unused_mut)]
let mut signing_config = #{sig_auth}::signer::OperationSigningConfig::default_config();
""",
*codegenScope
)
if (needsAmzSha256(service)) {
@ -128,6 +174,12 @@ class SigV4SigningFeature(
"${section.request}.properties_mut().insert(#{sig_auth}::signer::SignableBody::UnsignedPayload);",
*codegenScope
)
} else if (operation.isInputEventStream(model)) {
// TODO(EventStream): Is this actually correct for all Event Stream operations?
rustTemplate(
"${section.request}.properties_mut().insert(#{sig_auth}::signer::SignableBody::Bytes(&[]));",
*codegenScope
)
}
// some operations are either unsigned or optionally signed:
val authSchemes = serviceIndex.getEffectiveAuthSchemes(service, operation)
@ -140,9 +192,9 @@ class SigV4SigningFeature(
}
rustTemplate(
"""
${section.request}.properties_mut().insert(signing_config);
${section.request}.properties_mut().insert(#{aws_types}::SigningService::from_static(${section.config}.signing_service()));
""",
${section.request}.properties_mut().insert(signing_config);
${section.request}.properties_mut().insert(#{aws_types}::SigningService::from_static(${section.config}.signing_service()));
""",
*codegenScope
)
}

View File

@ -14,13 +14,19 @@ import software.amazon.smithy.rust.codegen.testutil.unitTest
internal class SigV4SigningCustomizationTest {
@Test
fun `generates a valid config`() {
val project = stubConfigProject(SigV4SigningConfig(SigV4Trait.builder().name("test-service").build()))
val project = stubConfigProject(
SigV4SigningConfig(
AwsTestRuntimeConfig,
true,
SigV4Trait.builder().name("test-service").build()
)
)
project.lib {
it.unitTest(
"""
let conf = crate::config::Config::builder().build();
assert_eq!(conf.signing_service(), "test-service");
"""
let conf = crate::config::Config::builder().build();
assert_eq!(conf.signing_service(), "test-service");
"""
)
}
project.compileAndTest()

View File

@ -235,6 +235,8 @@ task("generateSmithyBuild") {
projectDir.resolve("smithy-build.json").writeText(generateSmithyBuild(awsServices))
}
inputs.property("servicelist", awsServices.sortedBy { it.module }.toString())
// TODO(EventStream): Remove this when removing SMITHYRS_EXPERIMENTAL_EVENTSTREAM
inputs.property("_eventStreamCacheInvalidation", System.getenv("SMITHYRS_EXPERIMENTAL_EVENTSTREAM") ?: "0")
inputs.dir(projectDir.resolve("aws-models"))
outputs.file(projectDir.resolve("smithy-build.json"))
}

View File

@ -0,0 +1,18 @@
[package]
name = "transcribestreaming"
version = "0.1.0"
authors = ["John DiSanti <jdisanti@amazon.com>"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
aws-auth-providers = { path = "../../build/aws-sdk/aws-auth-providers" }
aws-sdk-transcribestreaming = { package = "aws-sdk-transcribestreaming", path = "../../build/aws-sdk/transcribestreaming" }
aws-types = { path = "../../build/aws-sdk/aws-types" }
async-stream = "0.3"
bytes = "1"
hound = "3.4"
tokio = { version = "1", features = ["full"] }
tracing-subscriber = "0.2.18"

View File

@ -0,0 +1,66 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
use async_stream::stream;
use aws_sdk_transcribestreaming::model::{AudioEvent, AudioStream, LanguageCode, MediaEncoding};
use aws_sdk_transcribestreaming::{Blob, Client, Config, Region};
use bytes::BufMut;
use std::time::Duration;
const CHUNK_SIZE: usize = 8192;
#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();
let input_stream = stream! {
let pcm = pcm_data();
for chunk in pcm.chunks(CHUNK_SIZE) {
// Sleeping isn't necessary, but emphasizes the streaming aspect of this
tokio::time::sleep(Duration::from_millis(100)).await;
yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build()));
}
// Must send an empty chunk at the end
yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(Vec::new())).build()));
};
let config = Config::builder()
.region(Region::from_static("us-west-2"))
.build();
let client = Client::from_conf(config);
let mut output = client
.start_stream_transcription()
.language_code(LanguageCode::EnGb)
.media_sample_rate_hertz(8000)
.media_encoding(MediaEncoding::Pcm)
.audio_stream(input_stream.into())
.send()
.await
.unwrap();
loop {
match output.transcript_result_stream.recv().await {
Ok(Some(transcription)) => {
println!("Received transcription response:\n{:?}\n", transcription)
}
Ok(None) => break,
Err(err) => println!("Received an error: {:?}", err),
}
}
println!("Done.")
}
fn pcm_data() -> Vec<u8> {
let audio = include_bytes!("../audio/hello-transcribe-8000.wav");
let reader = hound::WavReader::new(&audio[..]).unwrap();
let samples_result: hound::Result<Vec<i16>> = reader.into_samples::<i16>().collect();
let mut pcm: Vec<u8> = Vec::new();
for sample in samples_result.unwrap() {
pcm.put_i16_le(sample);
}
pcm
}

View File

@ -106,6 +106,8 @@ task("generateSmithyBuild") {
doFirst {
projectDir.resolve("smithy-build.json").writeText(generateSmithyBuild(CodegenTests))
}
// TODO(EventStream): Remove this when removing SMITHYRS_EXPERIMENTAL_EVENTSTREAM
inputs.property("_eventStreamCacheInvalidation", System.getenv("SMITHYRS_EXPERIMENTAL_EVENTSTREAM") ?: "0")
}

View File

@ -88,6 +88,9 @@ class InlineDependency(
private fun forRustFile(name: String, vararg additionalDependencies: RustDependency) =
forRustFile(name, "inlineable", *additionalDependencies)
fun eventStream(runtimeConfig: RuntimeConfig) =
forRustFile("event_stream", CargoDependency.SmithyEventStream(runtimeConfig))
fun jsonErrors(runtimeConfig: RuntimeConfig) =
forRustFile("json_errors", CargoDependency.Http, CargoDependency.SmithyTypes(runtimeConfig))
@ -118,8 +121,15 @@ data class CargoDependency(
private val location: DependencyLocation,
val scope: DependencyScope = DependencyScope.Compile,
val optional: Boolean = false,
private val features: List<String> = listOf()
val features: Set<String> = emptySet()
) : RustDependency(name) {
val key: Triple<String, DependencyLocation, DependencyScope> get() = Triple(name, location, scope)
fun canMergeWith(other: CargoDependency): Boolean = key == other.key
fun withFeature(feature: String): CargoDependency {
return copy(features = features.toMutableSet().apply { add(feature) })
}
override fun version(): String = when (location) {
is CratesIo -> location.version
@ -173,12 +183,15 @@ data class CargoDependency(
val Md5 = CargoDependency("md5", CratesIo("0.7"))
val FastRand = CargoDependency("fastrand", CratesIo("1"))
val Http: CargoDependency = CargoDependency("http", CratesIo("0.2"))
val Hyper: CargoDependency = CargoDependency("hyper", CratesIo("0.14"))
val HyperWithStream: CargoDependency = Hyper.withFeature("stream")
val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4"), optional = true)
fun SmithyTypes(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("types")
fun SmithyClient(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("client")
fun SmithyEventStream(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("eventstream")
fun SmithyHttp(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("http")
fun SmithyHttpTower(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("http-tower")
fun SmithyClient(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("client")
fun ProtocolTestHelpers(runtimeConfig: RuntimeConfig) = CargoDependency(
"protocol-test-helpers", runtimeConfig.runtimeCrateLocation.crateLocation(), scope = DependencyScope.Dev

View File

@ -101,8 +101,10 @@ fun CodegenWriterDelegator<RustWriter>.finalize(
this.useFileWriter("src/lib.rs", "crate::lib") { writer ->
LibRsGenerator(settings.moduleDescription, modules, libRsCustomizations).render(writer)
}
val cargoDependencies =
this.dependencies.map { RustDependency.fromSymbolDependency(it) }.filterIsInstance<CargoDependency>().distinct()
val cargoDependencies = mergeDependencyFeatures(
this.dependencies.map { RustDependency.fromSymbolDependency(it) }
.filterIsInstance<CargoDependency>().distinct()
)
this.useFileWriter("Cargo.toml") {
val cargoToml = CargoTomlGenerator(
settings,
@ -114,3 +116,18 @@ fun CodegenWriterDelegator<RustWriter>.finalize(
}
flushWriters()
}
private fun CargoDependency.mergeWith(other: CargoDependency): CargoDependency {
check(key == other.key)
return copy(
features = features + other.features,
optional = optional && other.optional
)
}
internal fun mergeDependencyFeatures(cargoDependencies: List<CargoDependency>): List<CargoDependency> =
cargoDependencies.groupBy { it.key }
.mapValues { group -> group.value.reduce { acc, next -> acc.mergeWith(next) } }
.values
.toList()
.sortedBy { it.name }

View File

@ -28,7 +28,10 @@ import software.amazon.smithy.rust.codegen.smithy.generators.implBlock
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolLoader
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.smithy.transformers.AddErrorMessage
import software.amazon.smithy.rust.codegen.smithy.transformers.EventStreamNormalizer
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.smithy.transformers.RecursiveShapeBoxer
import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStreamOperations
import software.amazon.smithy.rust.codegen.util.CommandFailed
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.hasTrait
@ -78,6 +81,9 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC
private fun baselineTransform(model: Model) =
model.let(RecursiveShapeBoxer::transform)
.letIf(settings.codegenConfig.addMessageToErrors, AddErrorMessage::transform)
.let(OperationNormalizer::transform)
.let(RemoveEventStreamOperations::transform)
.let(EventStreamNormalizer::transform)
fun execute() {
logger.info("generating Rust client...")

View File

@ -0,0 +1,65 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
package software.amazon.smithy.rust.codegen.smithy
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.RustType
import software.amazon.smithy.rust.codegen.rustlang.render
import software.amazon.smithy.rust.codegen.rustlang.stripOuter
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticOutputTrait
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.isEventStream
import software.amazon.smithy.rust.codegen.util.isInputEventStream
/**
* Wrapping symbol provider to wrap modeled types with the smithy-http Event Stream send/receive types.
*/
class EventStreamSymbolProvider(
private val runtimeConfig: RuntimeConfig,
base: RustSymbolProvider,
private val model: Model
) : WrappingSymbolProvider(base) {
override fun toSymbol(shape: Shape): Symbol {
val initial = super.toSymbol(shape)
// We only want to wrap with Event Stream types when dealing with member shapes
if (shape is MemberShape && shape.isEventStream(model)) {
// Determine if the member has a container that is a synthetic input or output
val operationShape = model.expectShape(shape.container).let { maybeInputOutput ->
val operationId = maybeInputOutput.getTrait<SyntheticInputTrait>()?.operation
?: maybeInputOutput.getTrait<SyntheticOutputTrait>()?.operation
operationId?.let { model.expectShape(it, OperationShape::class.java) }
}
// If we find an operation shape, then we can wrap the type
if (operationShape != null) {
val error = operationShape.errorSymbol(this).toSymbol()
val errorFmt = error.rustType().render(fullyQualified = true)
val innerFmt = initial.rustType().stripOuter<RustType.Option>().render(fullyQualified = true)
val outer = when (shape.isInputEventStream(model)) {
true -> "EventStreamInput<$innerFmt>"
else -> "Receiver<$innerFmt, $errorFmt>"
}
val rustType = RustType.Opaque(outer, "smithy_http::event_stream")
return initial.toBuilder()
.name(rustType.name)
.rustType(rustType)
.addReference(error)
.addReference(initial)
.addDependency(CargoDependency.SmithyHttp(runtimeConfig).withFeature("event-stream"))
.build()
}
}
return initial
}
}

View File

@ -161,6 +161,15 @@ data class RuntimeType(val name: String?, val dependency: RustDependency?, val n
val HttpRequestBuilder = Http("request::Builder")
val HttpResponseBuilder = Http("response::Builder")
val Hyper = CargoDependency.Hyper.asType()
fun eventStreamReceiver(runtimeConfig: RuntimeConfig): RuntimeType =
RuntimeType(
"Receiver",
dependency = CargoDependency.SmithyHttp(runtimeConfig),
"smithy_http::event_stream"
)
fun jsonErrors(runtimeConfig: RuntimeConfig) =
forInlineDependency(InlineDependency.jsonErrors(runtimeConfig))

View File

@ -28,6 +28,7 @@ class RustCodegenPlugin : SmithyBuildPlugin {
companion object {
fun baseSymbolProvider(model: Model, serviceShape: ServiceShape, symbolVisitorConfig: SymbolVisitorConfig = DefaultConfig) =
SymbolVisitor(model, serviceShape = serviceShape, config = symbolVisitorConfig)
.let { EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model) }
.let { StreamingShapeSymbolProvider(it, model) }
.let { BaseSymbolMetadataProvider(it) }
.let { StreamingShapeMetadataProvider(it, model) }

View File

@ -27,7 +27,10 @@ val ClippyAllowLints = listOf(
"should_implement_trait",
// protocol tests use silly names like `baz`, don't flag that
"blacklisted_name"
"blacklisted_name",
// Forcing use of `vec![]` can make codegen harder in some cases
"vec_init_then_push",
)
class AllowClippyLints : LibRsCustomization() {

View File

@ -5,6 +5,7 @@
package software.amazon.smithy.rust.codegen.smithy.generators
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.StructureShape
@ -51,13 +52,14 @@ class OperationBuildError(private val runtimeConfig: RuntimeConfig) {
fun MemberShape.setterName(): String = "set_${this.memberName.toSnakeCase()}"
class BuilderGenerator(
val model: Model,
private val model: Model,
private val symbolProvider: RustSymbolProvider,
private val shape: StructureShape
) {
private val members: List<MemberShape> = shape.allMembers.values.toList()
private val runtimeConfig = symbolProvider.config().runtimeConfig
private val members: List<MemberShape> = shape.allMembers.values.toList()
private val structureSymbol = symbolProvider.toSymbol(shape)
fun render(writer: RustWriter) {
val symbol = symbolProvider.toSymbol(shape)
// TODO: figure out exactly what docs we want on a the builder module
@ -104,6 +106,53 @@ class BuilderGenerator(
}
}
// TODO(EventStream): [DX] Consider updating builders to take EventInputStream as Into<EventInputStream>
private fun renderBuilderMember(writer: RustWriter, member: MemberShape, memberName: String, memberSymbol: Symbol) {
// builder members are crate-public to enable using them
// directly in serializers/deserializers
writer.write("pub(crate) $memberName: #T,", memberSymbol)
}
private fun renderBuilderMemberFn(
writer: RustWriter,
coreType: RustType,
member: MemberShape,
memberName: String,
memberSymbol: Symbol
) {
fun builderConverter(coreType: RustType) = when (coreType) {
is RustType.String,
is RustType.Box -> "input.into()"
else -> "input"
}
val signature = when (coreType) {
is RustType.String,
is RustType.Box -> "(mut self, input: impl Into<${coreType.render(true)}>) -> Self"
else -> "(mut self, input: ${coreType.render(true)}) -> Self"
}
writer.documentShape(member, model)
writer.rustBlock("pub fn $memberName$signature") {
write("self.$memberName = Some(${builderConverter(coreType)});")
write("self")
}
}
private fun renderBuilderMemberSetterFn(
writer: RustWriter,
outerType: RustType,
member: MemberShape,
memberName: String,
memberSymbol: Symbol
) {
// Render a `set_foo` method. This is useful as a target for code generation, because the argument type
// is the same as the resulting member type, and is always optional.
val inputType = outerType.asOptional()
writer.rustBlock("pub fn ${member.setterName()}(mut self, input: ${inputType.render(true)}) -> Self") {
rust("self.$memberName = input; self")
}
}
private fun renderBuilder(writer: RustWriter) {
val builderName = "Builder"
@ -119,18 +168,10 @@ class BuilderGenerator(
val memberName = symbolProvider.toMemberName(member)
// All fields in the builder are optional
val memberSymbol = symbolProvider.toSymbol(member).makeOptional()
// builder members are crate-public to enable using them
// directly in serializers/deserializers
write("pub(crate) $memberName: #T,", memberSymbol)
renderBuilderMember(this, member, memberName, memberSymbol)
}
}
fun builderConverter(coreType: RustType) = when (coreType) {
is RustType.String,
is RustType.Box -> "input.into()"
else -> "input"
}
writer.rustBlock("impl $builderName") {
members.forEach { member ->
// All fields in the builder are optional
@ -143,26 +184,10 @@ class BuilderGenerator(
when (coreType) {
is RustType.Vec -> renderVecHelper(memberName, coreType)
is RustType.HashMap -> renderMapHelper(memberName, coreType)
else -> {
val signature = when (coreType) {
is RustType.String,
is RustType.Box -> "(mut self, input: impl Into<${coreType.render(true)}>) -> Self"
else -> "(mut self, input: ${coreType.render(true)}) -> Self"
}
writer.documentShape(member, model)
writer.rustBlock("pub fn $memberName$signature") {
write("self.$memberName = Some(${builderConverter(coreType)});")
write("self")
}
}
else -> renderBuilderMemberFn(this, coreType, member, memberName, memberSymbol)
}
// Render a `set_foo` method. This is useful as a target for code generation, because the argument type
// is the same as the resulting member type, and is always optional.
val inputType = outerType.asOptional()
writer.rustBlock("pub fn ${member.setterName()}(mut self, input: ${inputType.render(true)}) -> Self") {
rust("self.$memberName = input; self")
}
renderBuilderMemberSetterFn(this, outerType, member, memberName, memberSymbol)
}
buildFn(this)
}

View File

@ -195,6 +195,7 @@ abstract class HttpProtocolGenerator(
) {
withBlock("Ok({", "})") {
features.forEach { it.section(OperationSection.MutateInput("self", "_config"))(this) }
rust("let properties = std::sync::Arc::new(std::sync::Mutex::new(smithy_http::property_bag::PropertyBag::new()));")
rust("let request = self.request_builder_base()?;")
withBlock("let body =", ";") {
body("self", shape)
@ -203,7 +204,7 @@ abstract class HttpProtocolGenerator(
rust(
"""
##[allow(unused_mut)]
let mut request = #T::Request::new(request.map(#T::from));
let mut request = #T::Request::from_parts(request.map(#T::from), properties);
""",
operationModule, sdkBody
)

View File

@ -146,7 +146,7 @@ class HttpProtocolTestGenerator(
val Tokio = CargoDependency(
"tokio",
CratesIo("1"),
features = listOf("macros", "test-util", "rt"),
features = setOf("macros", "test-util", "rt"),
scope = DependencyScope.Dev
)
testModuleWriter.addDependency(Tokio)

View File

@ -24,12 +24,12 @@ class UnionGenerator(
private val writer: RustWriter,
private val shape: UnionShape
) {
private val sortedMembers: List<MemberShape> = shape.allMembers.values.sortedBy { symbolProvider.toMemberName(it) }
fun render() {
renderUnion()
}
private val sortedMembers: List<MemberShape> = shape.allMembers.values.sortedBy { symbolProvider.toMemberName(it) }
private fun renderUnion() {
val unionSymbol = symbolProvider.toSymbol(shape)
val containerMeta = unionSymbol.expectRustMetadata()

View File

@ -116,12 +116,11 @@ class ServiceConfigGenerator(private val customizations: List<ConfigCustomizatio
writer.rustBlock("impl std::fmt::Debug for Config") {
rustTemplate(
"""
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut config = f.debug_struct("Config");
config.finish()
}
"""
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut config = f.debug_struct("Config");
config.finish()
}
"""
)
}
@ -129,7 +128,7 @@ class ServiceConfigGenerator(private val customizations: List<ConfigCustomizatio
rustTemplate(
"""
pub fn builder() -> Builder { Builder::default() }
"""
"""
)
customizations.forEach {
it.section(ServiceConfig.ConfigImpl)(this)

View File

@ -35,6 +35,8 @@ import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig
import software.amazon.smithy.rust.codegen.smithy.makeOptional
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingDescriptor
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.EventStreamUnmarshallerGenerator
import software.amazon.smithy.rust.codegen.smithy.rustType
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.hasTrait
@ -42,7 +44,11 @@ import software.amazon.smithy.rust.codegen.util.isPrimitive
import software.amazon.smithy.rust.codegen.util.isStreaming
import software.amazon.smithy.rust.codegen.util.toSnakeCase
class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val operationShape: OperationShape) {
class ResponseBindingGenerator(
private val protocol: Protocol,
protocolConfig: ProtocolConfig,
private val operationShape: OperationShape
) {
private val runtimeConfig = protocolConfig.runtimeConfig
private val symbolProvider = protocolConfig.symbolProvider
private val model = protocolConfig.model
@ -124,6 +130,7 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera
* Generate a function to deserialize `[binding]` from the response payload
*/
fun generateDeserializePayloadFn(
operationShape: OperationShape,
binding: HttpBindingDescriptor,
errorT: RuntimeType,
// Deserialize a single structure or union member marked as a payload
@ -142,7 +149,13 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera
outputT,
errorT
) {
deserializeStreamingBody(binding)
// Streaming unions are Event Streams and should be handled separately
val target = model.expectShape(binding.member.target)
if (target is UnionShape) {
bindEventStreamOutput(operationShape, target)
} else {
deserializeStreamingBody(binding)
}
}
} else {
rustWriter.rustBlock("pub fn $fnName(body: &[u8]) -> std::result::Result<#T, #T>", outputT, errorT) {
@ -157,6 +170,27 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera
}
}
private fun RustWriter.bindEventStreamOutput(operationShape: OperationShape, target: UnionShape) {
val unmarshallerConstructorFn = EventStreamUnmarshallerGenerator(
protocol,
model,
runtimeConfig,
symbolProvider,
operationShape,
target
).render()
rustTemplate(
"""
let unmarshaller = #{unmarshallerConstructorFn}();
let body = std::mem::replace(body, #{SdkBody}::taken());
Ok(#{Receiver}::new(unmarshaller, body))
""",
"SdkBody" to RuntimeType.sdkBody(runtimeConfig),
"unmarshallerConstructorFn" to unmarshallerConstructorFn,
"Receiver" to RuntimeType.eventStreamReceiver(runtimeConfig),
)
}
private fun RustWriter.deserializeStreamingBody(binding: HttpBindingDescriptor) {
val member = binding.member
val targetShape = model.expectShape(member.target)
@ -164,10 +198,10 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera
rustTemplate(
"""
// replace the body with an empty body
let body = std::mem::replace(body, #{sdk_body}::taken());
Ok(#{byte_stream}::new(body))
let body = std::mem::replace(body, #{SdkBody}::taken());
Ok(#{ByteStream}::new(body))
""",
"byte_stream" to RuntimeType.byteStream(runtimeConfig), "sdk_body" to RuntimeType.sdkBody(runtimeConfig)
"ByteStream" to RuntimeType.byteStream(runtimeConfig), "SdkBody" to RuntimeType.sdkBody(runtimeConfig)
)
}

View File

@ -8,7 +8,6 @@ package software.amazon.smithy.rust.codegen.smithy.protocols
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.pattern.UriPattern
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.ToShapeId
import software.amazon.smithy.model.traits.HttpTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
@ -24,9 +23,6 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.parse.JsonParserGene
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStreamOperations
import software.amazon.smithy.rust.codegen.smithy.transformers.StructureModifier
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.orNull
@ -47,17 +43,7 @@ class AwsJsonFactory(private val version: AwsJsonVersion) : ProtocolGeneratorFac
return HttpBoundProtocolGenerator(protocolConfig, AwsJson(protocolConfig, version))
}
private val shapeIfHasMembers: StructureModifier = { _, shape: StructureShape? ->
when (shape?.members().isNullOrEmpty()) {
true -> null
else -> shape
}
}
override fun transformModel(model: Model): Model {
// For AwsJson10, the body matches 1:1 with the input
return OperationNormalizer(model).transformModel().let(RemoveEventStreamOperations::transform)
}
override fun transformModel(model: Model): Model = model
override fun support(): ProtocolSupport = ProtocolSupport(
requestSerialization = true,

View File

@ -24,17 +24,13 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.parse.AwsQueryParser
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.AwsQuerySerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStreamOperations
import software.amazon.smithy.rust.codegen.util.getTrait
class AwsQueryFactory : ProtocolGeneratorFactory<HttpBoundProtocolGenerator> {
override fun buildProtocolGenerator(protocolConfig: ProtocolConfig): HttpBoundProtocolGenerator =
HttpBoundProtocolGenerator(protocolConfig, AwsQueryProtocol(protocolConfig))
override fun transformModel(model: Model): Model {
return OperationNormalizer(model).transformModel().let(RemoveEventStreamOperations::transform)
}
override fun transformModel(model: Model): Model = model
override fun support(): ProtocolSupport {
return ProtocolSupport(

View File

@ -22,16 +22,12 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.parse.Ec2QueryParser
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.Ec2QuerySerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStreamOperations
class Ec2QueryFactory : ProtocolGeneratorFactory<HttpBoundProtocolGenerator> {
override fun buildProtocolGenerator(protocolConfig: ProtocolConfig): HttpBoundProtocolGenerator =
HttpBoundProtocolGenerator(protocolConfig, Ec2QueryProtocol(protocolConfig))
override fun transformModel(model: Model): Model {
return OperationNormalizer(model).transformModel().let(RemoveEventStreamOperations::transform)
}
override fun transformModel(model: Model): Model = model
override fun support(): ProtocolSupport {
return ProtocolSupport(

View File

@ -40,6 +40,7 @@ import software.amazon.smithy.rust.codegen.smithy.generators.operationBuildError
import software.amazon.smithy.rust.codegen.smithy.generators.setterName
import software.amazon.smithy.rust.codegen.smithy.isOptional
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.EventStreamMarshallerGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.transformers.errorMessageMember
import software.amazon.smithy.rust.codegen.util.dq
@ -47,6 +48,8 @@ import software.amazon.smithy.rust.codegen.util.expectMember
import software.amazon.smithy.rust.codegen.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.isEventStream
import software.amazon.smithy.rust.codegen.util.isInputEventStream
import software.amazon.smithy.rust.codegen.util.isStreaming
import software.amazon.smithy.rust.codegen.util.outputShape
import software.amazon.smithy.rust.codegen.util.toSnakeCase
@ -81,10 +84,12 @@ class HttpBoundProtocolGenerator(
"ParseStrict" to RuntimeType.parseStrict(runtimeConfig),
"ParseResponse" to RuntimeType.parseResponse(runtimeConfig),
"http" to RuntimeType.http,
"hyper" to CargoDependency.HyperWithStream.asType(),
"operation" to RuntimeType.operationModule(runtimeConfig),
"Bytes" to RuntimeType.Bytes,
"SdkBody" to RuntimeType.sdkBody(runtimeConfig),
"BuildError" to runtimeConfig.operationBuildError()
"BuildError" to runtimeConfig.operationBuildError(),
"SmithyHttp" to CargoDependency.SmithyHttp(runtimeConfig).asType()
)
override fun RustWriter.body(self: String, operationShape: OperationShape): BodyMetadata {
@ -103,10 +108,50 @@ class HttpBoundProtocolGenerator(
BodyMetadata(takesOwnership = false)
} else {
val member = inputShape.expectMember(payloadMemberName)
serializeViaPayload(member, serializerGenerator)
if (operationShape.isInputEventStream(model)) {
serializeViaEventStream(operationShape, member, serializerGenerator)
} else {
serializeViaPayload(member, serializerGenerator)
}
}
}
private fun RustWriter.serializeViaEventStream(
operationShape: OperationShape,
memberShape: MemberShape,
serializerGenerator: StructuredDataSerializerGenerator
): BodyMetadata {
val memberName = symbolProvider.toMemberName(memberShape)
val unionShape = model.expectShape(memberShape.target, UnionShape::class.java)
val marshallerConstructorFn = EventStreamMarshallerGenerator(
model,
runtimeConfig,
symbolProvider,
unionShape,
serializerGenerator
).render()
// TODO(EventStream): [RPC] RPC protocols need to send an initial message with the
// parameters that are not `@eventHeader` or `@eventPayload`.
rustTemplate(
"""
{
let marshaller = #{marshallerConstructorFn}();
let signer = _config.new_event_stream_signer(properties.clone());
let adapter: #{SmithyHttp}::event_stream::MessageStreamAdapter<_, #{OperationError}> =
self.$memberName.into_body_stream(marshaller, signer);
let body: #{SdkBody} = #{hyper}::Body::wrap_stream(adapter).into();
body
}
""",
*codegenScope,
"marshallerConstructorFn" to marshallerConstructorFn,
"OperationError" to operationShape.errorSymbol(symbolProvider)
)
return BodyMetadata(takesOwnership = true)
}
private fun RustWriter.serializeViaPayload(
member: MemberShape,
serializerGenerator: StructuredDataSerializerGenerator
@ -175,6 +220,10 @@ class HttpBoundProtocolGenerator(
BodyMetadata(takesOwnership = true)
}
is StructureShape, is UnionShape -> {
check(
!((targetShape as? UnionShape)?.isEventStream() ?: false)
) { "Event Streams should be handled further up" }
// JSON serialize the structure or union targeted
rust(
"""#T(&$payloadName).map_err(|err|#T::SerializationError(err.into()))?""",
@ -430,7 +479,7 @@ class HttpBoundProtocolGenerator(
bindings: List<HttpBindingDescriptor>,
errorSymbol: RuntimeType,
) {
val httpBindingGenerator = ResponseBindingGenerator(protocolConfig, operationShape)
val httpBindingGenerator = ResponseBindingGenerator(protocol, protocolConfig, operationShape)
val structuredDataParser = protocol.structuredDataParser(operationShape)
Attribute.AllowUnusedMut.render(this)
rust("let mut output = #T::default();", outputShape.builderSymbol(symbolProvider))
@ -464,7 +513,7 @@ class HttpBoundProtocolGenerator(
}
val err = if (StructureGenerator.fallibleBuilder(outputShape, symbolProvider)) {
".map_err(|s|${format(errorSymbol)}::unhandled(s))?"
".map_err(${format(errorSymbol)}::unhandled)?"
} else ""
rust("output.build()$err")
}
@ -509,6 +558,7 @@ class HttpBoundProtocolGenerator(
rust("#T($body).map_err(#T::unhandled)", structuredDataParser.payloadParser(member), errorSymbol)
}
val deserializer = httpBindingGenerator.generateDeserializePayloadFn(
operationShape,
binding,
errorSymbol,
docHandler = docShapeHandler,

View File

@ -20,17 +20,13 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.parse.JsonParserGene
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStreamOperations
class RestJsonFactory : ProtocolGeneratorFactory<HttpBoundProtocolGenerator> {
override fun buildProtocolGenerator(
protocolConfig: ProtocolConfig
): HttpBoundProtocolGenerator = HttpBoundProtocolGenerator(protocolConfig, RestJson(protocolConfig))
override fun transformModel(model: Model): Model {
return OperationNormalizer(model).transformModel().let(RemoveEventStreamOperations::transform)
}
override fun transformModel(model: Model): Model = model
override fun support(): ProtocolSupport {
return ProtocolSupport(

View File

@ -21,8 +21,6 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.parse.RestXmlParserG
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.XmlBindingTraitSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStreamOperations
import software.amazon.smithy.rust.codegen.util.expectTrait
class RestXmlFactory(private val generator: (ProtocolConfig) -> Protocol = { RestXml(it) }) :
@ -33,9 +31,7 @@ class RestXmlFactory(private val generator: (ProtocolConfig) -> Protocol = { Res
return HttpBoundProtocolGenerator(protocolConfig, generator(protocolConfig))
}
override fun transformModel(model: Model): Model {
return OperationNormalizer(model).transformModel().let(RemoveEventStreamOperations::transform)
}
override fun transformModel(model: Model): Model = model
override fun support(): ProtocolSupport {
return ProtocolSupport(

View File

@ -0,0 +1,288 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
package software.amazon.smithy.rust.codegen.smithy.protocols.parse
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.BlobShape
import software.amazon.smithy.model.shapes.BooleanShape
import software.amazon.smithy.model.shapes.ByteShape
import software.amazon.smithy.model.shapes.IntegerShape
import software.amazon.smithy.model.shapes.LongShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ShortShape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.TimestampShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EventHeaderTrait
import software.amazon.smithy.model.traits.EventPayloadTrait
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.rustlang.withBlock
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticEventStreamUnionTrait
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectTrait
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.toPascalCase
class EventStreamUnmarshallerGenerator(
private val protocol: Protocol,
private val model: Model,
runtimeConfig: RuntimeConfig,
private val symbolProvider: RustSymbolProvider,
private val operationShape: OperationShape,
private val unionShape: UnionShape,
) {
private val unionSymbol = symbolProvider.toSymbol(unionShape)
private val operationErrorSymbol = operationShape.errorSymbol(symbolProvider)
private val smithyEventStream = CargoDependency.SmithyEventStream(runtimeConfig)
private val codegenScope = arrayOf(
"Blob" to RuntimeType("Blob", CargoDependency.SmithyTypes(runtimeConfig), "smithy_types"),
"Error" to RuntimeType("Error", smithyEventStream, "smithy_eventstream::error"),
"Header" to RuntimeType("Header", smithyEventStream, "smithy_eventstream::frame"),
"HeaderValue" to RuntimeType("HeaderValue", smithyEventStream, "smithy_eventstream::frame"),
"expect_fns" to RuntimeType("smithy", smithyEventStream, "smithy_eventstream"),
"Message" to RuntimeType("Message", smithyEventStream, "smithy_eventstream::frame"),
"SmithyError" to RuntimeType("Error", CargoDependency.SmithyTypes(runtimeConfig), "smithy_types"),
"UnmarshallMessage" to RuntimeType("UnmarshallMessage", smithyEventStream, "smithy_eventstream::frame"),
"UnmarshalledMessage" to RuntimeType("UnmarshalledMessage", smithyEventStream, "smithy_eventstream::frame"),
)
fun render(): RuntimeType {
val unmarshallerType = unionShape.eventStreamUnmarshallerType()
return RuntimeType.forInlineFun("${unmarshallerType.name}::new", "event_stream_serde") { inlineWriter ->
inlineWriter.renderUnmarshaller(unmarshallerType, unionSymbol)
}
}
private fun RustWriter.renderUnmarshaller(unmarshallerType: RuntimeType, unionSymbol: Symbol) {
rust(
"""
##[non_exhaustive]
##[derive(Debug)]
pub struct ${unmarshallerType.name};
impl ${unmarshallerType.name} {
pub fn new() -> Self {
${unmarshallerType.name}
}
}
"""
)
rustBlockTemplate(
"impl #{UnmarshallMessage} for ${unmarshallerType.name}",
*codegenScope
) {
rust("type Output = #T;", unionSymbol)
rust("type Error = #T;", operationErrorSymbol)
rustBlockTemplate(
"""
fn unmarshall(
&self,
message: &#{Message}
) -> std::result::Result<#{UnmarshalledMessage}<Self::Output, Self::Error>, #{Error}>
""",
*codegenScope
) {
rustBlockTemplate(
"""
let response_headers = #{expect_fns}::parse_response_headers(&message)?;
match response_headers.message_type.as_str()
""",
*codegenScope
) {
rustBlock("\"event\" => ") {
renderUnmarshallEvent()
}
rustBlock("\"exception\" => ") {
renderUnmarshallError()
}
rustBlock("value => ") {
rustTemplate(
"return Err(#{Error}::Unmarshalling(format!(\"unrecognized :message-type: {}\", value)));",
*codegenScope
)
}
}
}
}
}
private fun RustWriter.renderUnmarshallEvent() {
rustBlock("match response_headers.smithy_type.as_str()") {
for (member in unionShape.members()) {
val target = model.expectShape(member.target, StructureShape::class.java)
rustBlock("${member.memberName.dq()} => ") {
renderUnmarshallUnionMember(member, target)
}
}
rustBlock("smithy_type => ") {
// TODO: Handle this better once unions support unknown variants
rustTemplate(
"return Err(#{Error}::Unmarshalling(format!(\"unrecognized :event-type: {}\", smithy_type)));",
*codegenScope
)
}
}
}
private fun RustWriter.renderUnmarshallUnionMember(unionMember: MemberShape, unionStruct: StructureShape) {
val unionMemberName = unionMember.memberName.toPascalCase()
val payloadOnly =
unionStruct.members().none { it.hasTrait<EventPayloadTrait>() || it.hasTrait<EventHeaderTrait>() }
if (payloadOnly) {
withBlock("let parsed = ", ";") {
renderParseProtocolPayload(unionMember)
}
rustTemplate(
"Ok(#{UnmarshalledMessage}::Event(#{Output}::$unionMemberName(parsed)))",
"Output" to unionSymbol,
*codegenScope
)
} else {
rust("let mut builder = #T::builder();", symbolProvider.toSymbol(unionStruct))
val payloadMember = unionStruct.members().firstOrNull { it.hasTrait<EventPayloadTrait>() }
if (payloadMember != null) {
renderUnmarshallEventPayload(payloadMember)
}
val headerMembers = unionStruct.members().filter { it.hasTrait<EventHeaderTrait>() }
if (headerMembers.isNotEmpty()) {
rustBlock("for header in message.headers()") {
rustBlock("match header.name().as_str()") {
for (member in headerMembers) {
rustBlock("${member.memberName.dq()} => ") {
renderUnmarshallEventHeader(member)
}
}
rust("_ => {}")
}
}
}
rustTemplate(
"Ok(#{UnmarshalledMessage}::Event(#{Output}::$unionMemberName(builder.build())))",
"Output" to unionSymbol,
*codegenScope
)
}
}
private fun RustWriter.renderUnmarshallEventHeader(member: MemberShape) {
val memberName = symbolProvider.toMemberName(member)
withBlock("builder = builder.$memberName(", ");") {
when (val target = model.expectShape(member.target)) {
is BooleanShape -> rustTemplate("#{expect_fns}::expect_bool(header)?", *codegenScope)
is ByteShape -> rustTemplate("#{expect_fns}::expect_byte(header)?", *codegenScope)
is ShortShape -> rustTemplate("#{expect_fns}::expect_int16(header)?", *codegenScope)
is IntegerShape -> rustTemplate("#{expect_fns}::expect_int32(header)?", *codegenScope)
is LongShape -> rustTemplate("#{expect_fns}::expect_int64(header)?", *codegenScope)
is BlobShape -> rustTemplate("#{expect_fns}::expect_byte_array(header)?", *codegenScope)
is StringShape -> rustTemplate("#{expect_fns}::expect_string(header)?", *codegenScope)
is TimestampShape -> rustTemplate("#{expect_fns}::expect_timestamp(header)?", *codegenScope)
else -> throw IllegalStateException("unsupported event stream header shape type: $target")
}
}
}
private fun RustWriter.renderUnmarshallEventPayload(member: MemberShape) {
// TODO(EventStream): [RPC] Don't blow up on an initial-message that's not part of the union (:event-type will be "initial-request" or "initial-response")
// TODO(EventStream): [RPC] Incorporate initial-message into original output (:event-type will be "initial-request" or "initial-response")
val memberName = symbolProvider.toMemberName(member)
withBlock("builder = builder.$memberName(", ");") {
when (model.expectShape(member.target)) {
is BlobShape -> {
rustTemplate("#{Blob}::new(message.payload().as_ref())", *codegenScope)
}
is StringShape -> {
rustTemplate(
"""
std::str::from_utf8(message.payload())
.map_err(|_| #{Error}::Unmarshalling("message payload is not valid UTF-8".into()))?
""",
*codegenScope
)
}
is UnionShape, is StructureShape -> {
renderParseProtocolPayload(member)
}
}
}
}
private fun RustWriter.renderParseProtocolPayload(member: MemberShape) {
// TODO(EventStream): Check :content-type against expected content-type, error if unexpected
val parser = protocol.structuredDataParser(operationShape).payloadParser(member)
val memberName = member.memberName.toPascalCase()
rustTemplate(
"""
#{parser}(&message.payload()[..])
.map_err(|err| {
#{Error}::Unmarshalling(format!("failed to unmarshall $memberName: {}", err))
})?
""",
"parser" to parser,
*codegenScope
)
}
private fun RustWriter.renderUnmarshallError() {
val syntheticUnion = unionShape.expectTrait<SyntheticEventStreamUnionTrait>()
if (syntheticUnion.errorMembers.isNotEmpty()) {
rustBlock("match response_headers.smithy_type.as_str()") {
for (member in syntheticUnion.errorMembers) {
val target = model.expectShape(member.target, StructureShape::class.java)
rustBlock("${member.memberName.dq()} => ") {
val parser = protocol.structuredDataParser(operationShape).errorParser(target)
if (parser != null) {
rust("let mut builder = #T::builder();", symbolProvider.toSymbol(target))
// TODO(EventStream): Errors on the operation can be disjoint with errors in the union,
// so we need to generate a new top-level Error type for each event stream union.
rustTemplate(
"""
builder = #{parser}(&message.payload()[..], builder)
.map_err(|err| {
#{Error}::Unmarshalling(format!("failed to unmarshall ${member.memberName}: {}", err))
})?;
return Ok(#{UnmarshalledMessage}::Error(
#{OpError}::new(
#{OpError}Kind::${member.memberName.toPascalCase()}(builder.build()),
#{SmithyError}::builder().build(),
)
))
""",
"OpError" to operationErrorSymbol,
"parser" to parser,
*codegenScope
)
}
}
}
rust("_ => {}")
}
}
// TODO(EventStream): Generic error parsing; will need to refactor `parseGenericError` to
// operate on bodies rather than responses. This should be easy for all but restJson,
// which pulls the error type out of a header.
rust("unimplemented!(\"event stream generic error parsing\")")
}
private fun UnionShape.eventStreamUnmarshallerType(): RuntimeType {
val symbol = symbolProvider.toSymbol(this)
return RuntimeType("${symbol.name.toPascalCase()}Unmarshaller", null, "crate::event_stream_serde")
}
}

View File

@ -0,0 +1,161 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
package software.amazon.smithy.rust.codegen.smithy.protocols.serialize
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.BlobShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EventHeaderTrait
import software.amazon.smithy.model.traits.EventPayloadTrait
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.render
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.rustType
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.toPascalCase
// TODO(EventStream): [TEST] Unit test EventStreamMarshallerGenerator
class EventStreamMarshallerGenerator(
private val model: Model,
runtimeConfig: RuntimeConfig,
private val symbolProvider: RustSymbolProvider,
private val unionShape: UnionShape,
private val serializerGenerator: StructuredDataSerializerGenerator,
) {
private val smithyEventStream = CargoDependency.SmithyEventStream(runtimeConfig)
private val codegenScope = arrayOf(
"MarshallMessage" to RuntimeType("MarshallMessage", smithyEventStream, "smithy_eventstream::frame"),
"Message" to RuntimeType("Message", smithyEventStream, "smithy_eventstream::frame"),
"Header" to RuntimeType("Header", smithyEventStream, "smithy_eventstream::frame"),
"HeaderValue" to RuntimeType("HeaderValue", smithyEventStream, "smithy_eventstream::frame"),
"Error" to RuntimeType("Error", smithyEventStream, "smithy_eventstream::error"),
)
fun render(): RuntimeType {
val marshallerType = unionShape.eventStreamMarshallerType()
val unionSymbol = symbolProvider.toSymbol(unionShape)
return RuntimeType.forInlineFun("${marshallerType.name}::new", "event_stream_serde") { inlineWriter ->
inlineWriter.renderMarshaller(marshallerType, unionSymbol)
}
}
private fun RustWriter.renderMarshaller(marshallerType: RuntimeType, unionSymbol: Symbol) {
rust(
"""
##[non_exhaustive]
##[derive(Debug)]
pub struct ${marshallerType.name};
impl ${marshallerType.name} {
pub fn new() -> Self {
${marshallerType.name}
}
}
"""
)
rustBlockTemplate(
"impl #{MarshallMessage} for ${marshallerType.name}",
*codegenScope
) {
rust("type Input = ${unionSymbol.rustType().render(fullyQualified = true)};")
rustBlockTemplate(
"fn marshall(&self, input: Self::Input) -> std::result::Result<#{Message}, #{Error}>",
*codegenScope
) {
rust("let mut headers = Vec::new();")
addStringHeader(":message-type", "\"event\".into()")
rustBlock("let payload = match input") {
for (member in unionShape.members()) {
val eventType = member.memberName // must be the original name, not the Rust-safe name
rustBlock("Self::Input::${member.memberName.toPascalCase()}(inner) => ") {
addStringHeader(":event-type", "${eventType.dq()}.into()")
val target = model.expectShape(member.target, StructureShape::class.java)
serializeEvent(target)
}
}
}
rustTemplate("; Ok(#{Message}::new_from_parts(headers, payload))", *codegenScope)
}
}
}
private fun RustWriter.serializeEvent(struct: StructureShape) {
for (member in struct.members()) {
val memberName = symbolProvider.toMemberName(member)
val target = model.expectShape(member.target)
if (member.hasTrait<EventPayloadTrait>()) {
serializeUnionMember(memberName, member, target)
} else if (member.hasTrait<EventHeaderTrait>()) {
TODO("TODO(EventStream): Implement @eventHeader trait")
} else {
throw IllegalStateException("Event Stream members must be a header or payload")
}
}
}
private fun RustWriter.serializeUnionMember(memberName: String, member: MemberShape, target: Shape) {
if (target is BlobShape || target is StringShape) {
data class PayloadContext(val conversionFn: String, val contentType: String)
val ctx = when (target) {
is BlobShape -> PayloadContext("into_inner", "application/octet-stream")
is StringShape -> PayloadContext("into_bytes", "text/plain")
else -> throw IllegalStateException("unreachable")
}
addStringHeader(":content-type", "${ctx.contentType.dq()}.into()")
if (member.isOptional) {
rust(
"""
if let Some(inner_payload) = inner.$memberName {
inner_payload.${ctx.conversionFn}()
} else {
Vec::new()
}
"""
)
} else {
rust("inner.$memberName.${ctx.conversionFn}()")
}
} else {
// TODO(EventStream): Select content-type based on protocol
addStringHeader(":content-type", "\"TODO\".into()")
val serializerFn = serializerGenerator.payloadSerializer(member)
rustTemplate(
"""
#{serializerFn}(&inner.$memberName)
.map_err(|err| #{Error}::Marshalling(format!("{}", err)))?
""",
"serializerFn" to serializerFn,
*codegenScope
)
}
}
private fun RustWriter.addStringHeader(name: String, valueExpr: String) {
rustTemplate("headers.push(#{Header}::new(${name.dq()}, #{HeaderValue}::String($valueExpr)));", *codegenScope)
}
private fun UnionShape.eventStreamMarshallerType(): RuntimeType {
val symbol = symbolProvider.toSymbol(this)
return RuntimeType("${symbol.name.toPascalCase()}Marshaller", null, "crate::event_stream_serde")
}
}

View File

@ -141,7 +141,7 @@ class JsonSerializerGenerator(
val target = model.expectShape(member.target, StructureShape::class.java)
return RuntimeType.forInlineFun(fnName, "operation_ser") { writer ->
writer.rustBlockTemplate(
"pub fn $fnName(input: &#{target}) -> Result<#{SdkBody}, #{Error}>",
"pub fn $fnName(input: &#{target}) -> std::result::Result<std::vec::Vec<u8>, #{Error}>",
*codegenScope,
"target" to symbolProvider.toSymbol(target)
) {
@ -149,7 +149,7 @@ class JsonSerializerGenerator(
rustTemplate("let mut object = #{JsonObjectWriter}::new(&mut out);", *codegenScope)
serializeStructure(StructContext("object", "input", target))
rust("object.finish();")
rustTemplate("Ok(#{SdkBody}::from(out))", *codegenScope)
rustTemplate("Ok(out.into_bytes())", *codegenScope)
}
}
}

View File

@ -101,6 +101,8 @@ abstract class QuerySerializerGenerator(protocolConfig: ProtocolConfig) : Struct
}
override fun payloadSerializer(member: MemberShape): RuntimeType {
// TODO(EventStream): [RPC] The query will need to be rendered to the initial message,
// so this needs to be implemented
TODO("The $protocolName protocol doesn't support http payload traits")
}

View File

@ -11,30 +11,30 @@ import software.amazon.smithy.rust.codegen.smithy.RuntimeType
interface StructuredDataSerializerGenerator {
/**
* Generate a parse function for a given targeted as a payload.
* Entry point for payload-based parsing.
* Roughly:
* Generate a serializer for a request payload. Expected signature:
* ```rust
* fn serialize_some_payload(input: &PayloadSmithyType) -> Result<Vec<u8>, Error> {
* ...
* }
* ```
*/
fun payloadSerializer(member: MemberShape): RuntimeType
/** Generate a serializer for operation input
* Because only a subset of fields of the operation may be impacted by the document, a builder is passed
* through:
*
/**
* Generate a serializer for an operation input.
* ```rust
* fn parse_some_operation(inp: &[u8], builder: my_operation::Builder) -> Result<my_operation::Builder, XmlError> {
* ...
* fn serialize_some_operation(input: &SomeSmithyType) -> Result<SdkBody, Error> {
* ...
* }
* ```
*/
fun operationSerializer(operationShape: OperationShape): RuntimeType?
/**
* Generate a serializer for a document.
* ```rust
* fn parse_document(inp: &[u8]) -> Result<Document, Error> {
* ...
* fn serialize_document(input: &Document) -> Result<SdkBody, Error> {
* ...
* }
* ```
*/

View File

@ -121,7 +121,7 @@ class XmlBindingTraitSerializerGenerator(
let mut writer = #{XmlWriter}::new(&mut out);
##[allow(unused_mut)]
let mut root = writer.start_el(${operationXmlName.dq()})${inputShape.xmlNamespace().apply()};
""",
""",
*codegenScope
)
serializeStructure(inputShape, xmlMembers, Ctx.Element("root", "&input"))
@ -140,10 +140,9 @@ class XmlBindingTraitSerializerGenerator(
val target = model.expectShape(member.target, StructureShape::class.java)
return RuntimeType.forInlineFun(fnName, "xml_ser") {
val t = symbolProvider.toSymbol(member).rustType().stripOuter<RustType.Option>().render(true)
it.rustBlock(
"pub fn $fnName(input: &$t) -> Result<#T, String>",
RuntimeType.sdkBody(runtimeConfig),
it.rustBlockTemplate(
"pub fn $fnName(input: &$t) -> std::result::Result<std::vec::Vec<u8>, String>",
*codegenScope
) {
rust("let mut out = String::new();")
// create a scope for writer. This ensure that writer has been dropped before returning the
@ -156,7 +155,7 @@ class XmlBindingTraitSerializerGenerator(
let mut root = writer.start_el(${xmlIndex.payloadShapeName(member).dq()})${
target.xmlNamespace().apply()
};
""",
""",
*codegenScope
)
serializeStructure(
@ -165,7 +164,7 @@ class XmlBindingTraitSerializerGenerator(
Ctx.Element("root", "&input")
)
}
rustTemplate("Ok(#{SdkBody}::from(out))", *codegenScope)
rustTemplate("Ok(out.into_bytes())", *codegenScope)
}
}
}

View File

@ -0,0 +1,19 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
package software.amazon.smithy.rust.codegen.smithy.traits
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.traits.AnnotationTrait
class SyntheticEventStreamUnionTrait(
val errorMembers: List<MemberShape>,
) : AnnotationTrait(ID, Node.objectNode()) {
companion object {
val ID = ShapeId.from("smithy.api.internal#syntheticEventStreamUnion")
}
}

View File

@ -0,0 +1,39 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
package software.amazon.smithy.rust.codegen.smithy.transformers
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticEventStreamUnionTrait
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.isEventStream
/**
* Generates synthetic unions to replace the modeled unions for Event Stream types.
* This allows us to strip out all the error union members once up-front, instead of in each
* place that does codegen with the unions.
*/
object EventStreamNormalizer {
fun transform(model: Model): Model = ModelTransformer.create().mapShapes(model) { shape ->
if (shape is UnionShape && shape.isEventStream()) {
syntheticEquivalent(model, shape)
} else {
shape
}
}
private fun syntheticEquivalent(model: Model, union: UnionShape): UnionShape {
val (errorMembers, eventMembers) = union.members().partition { member ->
model.expectShape(member.target).hasTrait<ErrorTrait>()
}
return union.toBuilder()
.members(eventMembers)
.addTrait(SyntheticEventStreamUnionTrait(errorMembers))
.build()
}
}

View File

@ -17,23 +17,25 @@ import software.amazon.smithy.rust.codegen.util.orNull
import java.util.Optional
import kotlin.streams.toList
typealias StructureModifier = (OperationShape, StructureShape?) -> StructureShape?
/**
* Generate synthetic Input and Output structures for operations.
*/
class OperationNormalizer(private val model: Model) {
object OperationNormalizer {
// Functions to construct synthetic shape IDs—Don't rely on these in external code.
// Rename safety: Operations cannot be renamed
private fun OperationShape.syntheticInputId() = ShapeId.fromParts(this.id.namespace, "${this.id.name}Input")
private fun OperationShape.syntheticOutputId() = ShapeId.fromParts(this.id.namespace, "${this.id.name}Output")
/**
* Add synthetic input & output shapes to every Operation in model. The generated shapes will be marked with
* [SyntheticInputTrait] and [SyntheticOutputTrait] respectively. Shapes will be added _even_ if the operation does
* not specify an input or an output.
*/
fun transformModel(): Model {
fun transform(model: Model): Model {
val transformer = ModelTransformer.create()
val operations = model.shapes(OperationShape::class.java).toList()
val newShapes = operations.flatMap { operation ->
// Generate or modify the input and output of the given `Operation` to be a unique shape
syntheticInputShapes(operation) + syntheticOutputShapes(operation)
syntheticInputShapes(model, operation) + syntheticOutputShapes(model, operation)
}
val modelWithOperationInputs = model.toBuilder().addShapes(newShapes).build()
return transformer.mapShapes(modelWithOperationInputs) {
@ -49,7 +51,7 @@ class OperationNormalizer(private val model: Model) {
}
}
private fun syntheticOutputShapes(operation: OperationShape): List<StructureShape> {
private fun syntheticOutputShapes(model: Model, operation: OperationShape): List<StructureShape> {
val outputId = operation.syntheticOutputId()
val outputShapeBuilder = operation.output.map { shapeId ->
model.expectShape(shapeId, StructureShape::class.java).toBuilder().rename(outputId)
@ -63,7 +65,7 @@ class OperationNormalizer(private val model: Model) {
return listOfNotNull(outputShape)
}
private fun syntheticInputShapes(operation: OperationShape): List<StructureShape> {
private fun syntheticInputShapes(model: Model, operation: OperationShape): List<StructureShape> {
val inputId = operation.syntheticInputId()
val inputShapeBuilder = operation.input.map { shapeId ->
model.expectShape(shapeId, StructureShape::class.java).toBuilder().rename(inputId)
@ -79,13 +81,6 @@ class OperationNormalizer(private val model: Model) {
}
private fun empty(id: ShapeId) = StructureShape.builder().id(id)
companion object {
// Functions to construct synthetic shape IDs—Don't rely on these in external code.
// Rename safety: Operations cannot be renamed
private fun OperationShape.syntheticInputId() = ShapeId.fromParts(this.id.namespace, "${this.id.name}Input")
private fun OperationShape.syntheticOutputId() = ShapeId.fromParts(this.id.namespace, "${this.id.name}Output")
}
}
private fun StructureShape.Builder.rename(newId: ShapeId): StructureShape.Builder {

View File

@ -13,23 +13,38 @@ import software.amazon.smithy.rust.codegen.util.findStreamingMember
import software.amazon.smithy.rust.codegen.util.orNull
import java.util.logging.Logger
// TODO(EventStream): Remove this class once the Event Stream implementation is stable
/** Transformer to REMOVE operations that use EventStreaming until event streaming is supported */
object RemoveEventStreamOperations {
private val logger = Logger.getLogger(javaClass.name)
fun transform(model: Model): Model = ModelTransformer.create().filterShapes(model) { parentShape ->
if (parentShape !is OperationShape) {
true
} else {
val ioShapes = listOfNotNull(parentShape.output.orNull(), parentShape.input.orNull()).map { model.expectShape(it, StructureShape::class.java) }
val hasEventStream = ioShapes.any { ioShape ->
val streamingMember = ioShape.findStreamingMember(model)?.let { model.expectShape(it.target) }
streamingMember?.isUnionShape ?: false
}
// If a streaming member has a union trait, it is an event stream. Event Streams are not currently supported
// by the SDK, so if we generate this API it won't work.
(!hasEventStream).also {
if (!it) {
logger.info("Removed $parentShape from model because it targets an event stream")
private fun eventStreamEnabled(): Boolean =
System.getenv()["SMITHYRS_EXPERIMENTAL_EVENTSTREAM"] == "1"
fun transform(model: Model): Model {
if (eventStreamEnabled()) {
return model
}
return ModelTransformer.create().filterShapes(model) { parentShape ->
if (parentShape !is OperationShape) {
true
} else {
val ioShapes = listOfNotNull(parentShape.output.orNull(), parentShape.input.orNull()).map {
model.expectShape(
it,
StructureShape::class.java
)
}
val hasEventStream = ioShapes.any { ioShape ->
val streamingMember = ioShape.findStreamingMember(model)?.let { model.expectShape(it.target) }
streamingMember?.isUnionShape ?: false
}
// If a streaming member has a union trait, it is an event stream. Event Streams are not currently supported
// by the SDK, so if we generate this API it won't work.
(!hasEventStream).also {
if (!it) {
logger.info("Removed $parentShape from model because it targets an event stream")
}
}
}
}

View File

@ -157,7 +157,12 @@ class TestWriterDelegator(fileManifest: FileManifest, symbolProvider: RustSymbol
val baseDir: Path = fileManifest.baseDir
}
fun TestWriterDelegator.compileAndTest() {
/**
* Setting `runClippy` to true can be helpful when debugging clippy failures, but
* should generally be set to false to avoid invalidating the Cargo cache between
* every unit test run.
*/
fun TestWriterDelegator.compileAndTest(runClippy: Boolean = false) {
val stubModel = """
namespace fake
service Fake {
@ -183,6 +188,9 @@ fun TestWriterDelegator.compileAndTest() {
// cargo fmt errors are useless, ignore
}
"cargo test".runCommand(baseDir, mapOf("RUSTFLAGS" to "-A dead_code"))
if (runClippy) {
"cargo clippy".runCommand(baseDir)
}
}
// TODO: unify these test helpers a bit

View File

@ -11,12 +11,14 @@ import software.amazon.smithy.model.shapes.BooleanShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.NumberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.StreamingTrait
import software.amazon.smithy.model.traits.Trait
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
inline fun <reified T : Shape> Model.lookup(shapeId: String): T {
return this.expectShape(ShapeId.from(shapeId), T::class.java)
@ -42,6 +44,37 @@ fun StructureShape.hasStreamingMember(model: Model) = this.findStreamingMember(m
fun UnionShape.hasStreamingMember(model: Model) = this.findMemberWithTrait<StreamingTrait>(model) != null
fun MemberShape.isStreaming(model: Model) = this.getMemberTrait(model, StreamingTrait::class.java).isPresent
fun UnionShape.isEventStream(): Boolean {
return hasTrait(StreamingTrait::class.java)
}
fun MemberShape.isEventStream(model: Model): Boolean {
return (model.expectShape(target) as? UnionShape)?.isEventStream() ?: false
}
fun MemberShape.isInputEventStream(model: Model): Boolean {
return isEventStream(model) && model.expectShape(container).hasTrait<SyntheticInputTrait>()
}
fun MemberShape.isOutputEventStream(model: Model): Boolean {
return isEventStream(model) && model.expectShape(container).hasTrait<SyntheticInputTrait>()
}
private fun Shape.hasEventStreamMember(model: Model): Boolean {
return members().any { it.isEventStream(model) }
}
fun OperationShape.isInputEventStream(model: Model): Boolean {
return input.map { id -> model.expectShape(id).hasEventStreamMember(model) }.orElse(false)
}
fun OperationShape.isOutputEventStream(model: Model): Boolean {
return output.map { id -> model.expectShape(id).hasEventStreamMember(model) }.orElse(false)
}
fun OperationShape.isEventStream(model: Model): Boolean {
return isInputEventStream(model) || isOutputEventStream(model)
}
fun ServiceShape.hasEventStreamOperations(model: Model): Boolean = operations.any { id ->
// Don't assume all of the looked up operation ids are operation shapes. Our
// synthetic input/output structure shapes can have the same name as an operation,
// as is the case with `kinesisanalytics`.
model.getShape(id).orNull()?.let { it is OperationShape && it.isEventStream(model) } ?: false
}
/*
* Returns the member of this structure targeted with streaming trait (if it exists).
*

View File

@ -101,7 +101,7 @@ class RequestBindingGeneratorTest {
stringHeader: String
}
""".asSmithyModel()
private val model = OperationNormalizer(baseModel).transformModel()
private val model = OperationNormalizer.transform(baseModel)
private val operationShape = model.expectShape(ShapeId.from("smithy.example#PutObject"), OperationShape::class.java)
private val inputShape = model.expectShape(operationShape.input.get(), StructureShape::class.java)

View File

@ -16,6 +16,7 @@ import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig
import software.amazon.smithy.rust.codegen.smithy.generators.http.ResponseBindingGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpTraitHttpBindingResolver
import software.amazon.smithy.rust.codegen.smithy.protocols.RestJson
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.testutil.TestWorkspace
import software.amazon.smithy.rust.codegen.testutil.asSmithyModel
@ -65,7 +66,7 @@ class ResponseBindingGeneratorTest {
additional: String,
}
""".asSmithyModel()
private val model = OperationNormalizer(baseModel).transformModel()
private val model = OperationNormalizer.transform(baseModel)
private val operationShape = model.expectShape(ShapeId.from("smithy.example#PutObject"), OperationShape::class.java)
private val symbolProvider = testSymbolProvider(model)
private val testProtocolConfig: ProtocolConfig = testProtocolConfig(model)
@ -78,7 +79,9 @@ class ResponseBindingGeneratorTest {
.filter { it.location == HttpLocation.HEADER }
bindings.forEach { binding ->
val runtimeType = ResponseBindingGenerator(
testProtocolConfig, operationShape
RestJson(testProtocolConfig),
testProtocolConfig,
operationShape
).generateDeserializeHeaderFn(binding)
// little hack to force these functions to be generated
rust("// use #T;", runtimeType)

View File

@ -0,0 +1,38 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
package software.amazon.smithy.rust.codegen.smithy
import io.kotest.matchers.shouldBe
import org.junit.jupiter.api.Test
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.CratesIo
import software.amazon.smithy.rust.codegen.rustlang.DependencyScope.Compile
class CodegenDelegatorTest {
@Test
fun testMergeDependencyFeatures() {
val merged = mergeDependencyFeatures(
listOf(
CargoDependency("A", CratesIo("1"), Compile, optional = false, features = setOf()),
CargoDependency("A", CratesIo("1"), Compile, optional = false, features = setOf("f1")),
CargoDependency("A", CratesIo("1"), Compile, optional = false, features = setOf("f2")),
CargoDependency("A", CratesIo("1"), Compile, optional = false, features = setOf("f1", "f2")),
CargoDependency("B", CratesIo("2"), Compile, optional = false, features = setOf()),
CargoDependency("B", CratesIo("2"), Compile, optional = true, features = setOf()),
CargoDependency("C", CratesIo("3"), Compile, optional = true, features = setOf()),
CargoDependency("C", CratesIo("3"), Compile, optional = true, features = setOf()),
).shuffled()
)
merged shouldBe setOf(
CargoDependency("A", CratesIo("1"), Compile, optional = false, features = setOf("f1", "f2")),
CargoDependency("B", CratesIo("2"), Compile, optional = false, features = setOf()),
CargoDependency("C", CratesIo("3"), Compile, optional = true, features = setOf()),
)
}
}

View File

@ -0,0 +1,92 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
package software.amazon.smithy.rust.codegen.smithy
import io.kotest.matchers.shouldBe
import org.junit.jupiter.api.Test
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.rust.codegen.rustlang.RustType
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.testutil.TestRuntimeConfig
import software.amazon.smithy.rust.codegen.testutil.asSmithyModel
class EventStreamSymbolProviderTest {
@Test
fun `it should adjust types for operations with event streams`() {
// Transform the model so that it has synthetic inputs/outputs
val model = OperationNormalizer.transform(
"""
namespace test
structure Something { stuff: Blob }
@streaming
union SomeStream {
Something: Something,
}
structure TestInput { inputStream: SomeStream }
structure TestOutput { outputStream: SomeStream }
operation TestOperation {
input: TestInput,
output: TestOutput,
}
service TestService { version: "123", operations: [TestOperation] }
""".asSmithyModel()
)
val service = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape
val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, DefaultConfig), model)
// Look up the synthetic input/output rather than the original input/output
val inputStream = model.expectShape(ShapeId.from("test#TestOperationInput\$inputStream")) as MemberShape
val outputStream = model.expectShape(ShapeId.from("test#TestOperationOutput\$outputStream")) as MemberShape
val inputType = provider.toSymbol(inputStream).rustType()
val outputType = provider.toSymbol(outputStream).rustType()
inputType shouldBe RustType.Opaque("EventStreamInput<crate::model::SomeStream>", "smithy_http::event_stream")
outputType shouldBe RustType.Opaque("Receiver<crate::model::SomeStream, crate::error::TestOperationError>", "smithy_http::event_stream")
}
@Test
fun `it should leave alone types for operations without event streams`() {
val model = OperationNormalizer.transform(
"""
namespace test
structure Something { stuff: Blob }
union NotStreaming {
Something: Something,
}
structure TestInput { inputStream: NotStreaming }
structure TestOutput { outputStream: NotStreaming }
operation TestOperation {
input: TestInput,
output: TestOutput,
}
service TestService { version: "123", operations: [TestOperation] }
""".asSmithyModel()
)
val service = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape
val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, DefaultConfig), model)
// Look up the synthetic input/output rather than the original input/output
val inputStream = model.expectShape(ShapeId.from("test#TestOperationInput\$inputStream")) as MemberShape
val outputStream = model.expectShape(ShapeId.from("test#TestOperationOutput\$outputStream")) as MemberShape
val inputType = provider.toSymbol(inputStream).rustType()
val outputType = provider.toSymbol(outputStream).rustType()
inputType shouldBe RustType.Option(RustType.Opaque("NotStreaming", "crate::model"))
outputType shouldBe RustType.Option(RustType.Opaque("NotStreaming", "crate::model"))
}
}

View File

@ -34,7 +34,7 @@ internal class StreamingShapeSymbolProviderTest {
fun `generates a byte stream on streaming output`() {
// we could test exactly the streaming shape symbol provider, but we actually care about is the full stack
// "doing the right thing"
val modelWithOperationTraits = OperationNormalizer(model).transformModel()
val modelWithOperationTraits = OperationNormalizer.transform(model)
val symbolProvider = testSymbolProvider(modelWithOperationTraits)
symbolProvider.toSymbol(modelWithOperationTraits.lookup<MemberShape>("test#GenerateSpeechOutput\$data")).name shouldBe ("byte_stream::ByteStream")
symbolProvider.toSymbol(modelWithOperationTraits.lookup<MemberShape>("test#GenerateSpeechInput\$data")).name shouldBe ("byte_stream::ByteStream")
@ -42,7 +42,7 @@ internal class StreamingShapeSymbolProviderTest {
@Test
fun `streaming members have a default`() {
val modelWithOperationTraits = OperationNormalizer(model).transformModel()
val modelWithOperationTraits = OperationNormalizer.transform(model)
val symbolProvider = testSymbolProvider(modelWithOperationTraits)
val outputSymbol = symbolProvider.toSymbol(modelWithOperationTraits.lookup<MemberShape>("test#GenerateSpeechOutput\$data"))

View File

@ -22,8 +22,6 @@ import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolMap
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStreamOperations
import software.amazon.smithy.rust.codegen.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.testutil.generatePluginContext
import software.amazon.smithy.rust.codegen.util.CommandFailed
@ -163,9 +161,7 @@ class HttpProtocolTestGeneratorTest {
return TestProtocol(protocolConfig)
}
override fun transformModel(model: Model): Model {
return OperationNormalizer(model).transformModel().let(RemoveEventStreamOperations::transform)
}
override fun transformModel(model: Model): Model = model
override fun support(): ProtocolSupport {
return ProtocolSupport(true, true, true, true)

View File

@ -0,0 +1,307 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
package software.amazon.smithy.rust.codegen.smithy.protocols
import org.junit.jupiter.api.extension.ExtensionContext
import org.junit.jupiter.params.provider.Arguments
import org.junit.jupiter.params.provider.ArgumentsProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.rust.codegen.rustlang.RustModule
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig
import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.error.CombinedErrorGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.implBlock
import software.amazon.smithy.rust.codegen.smithy.transformers.EventStreamNormalizer
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.testutil.TestWorkspace
import software.amazon.smithy.rust.codegen.testutil.TestWriterDelegator
import software.amazon.smithy.rust.codegen.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.testutil.renderWithModelBuilder
import software.amazon.smithy.rust.codegen.testutil.testSymbolProvider
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.lookup
import software.amazon.smithy.rust.codegen.util.outputShape
import java.util.stream.Stream
private fun fillInBaseModel(
protocolName: String,
extraServiceAnnotations: String = "",
): String = """
namespace test
use aws.protocols#$protocolName
union TestUnion {
Foo: String,
Bar: Integer,
}
structure TestStruct {
someString: String,
someInt: Integer,
}
@error("client")
structure SomeError {
Message: String,
}
structure MessageWithBlob { @eventPayload data: Blob }
structure MessageWithString { @eventPayload data: String }
structure MessageWithStruct { @eventPayload someStruct: TestStruct }
structure MessageWithUnion { @eventPayload someUnion: TestUnion }
structure MessageWithHeaders {
@eventHeader blob: Blob,
@eventHeader boolean: Boolean,
@eventHeader byte: Byte,
@eventHeader int: Integer,
@eventHeader long: Long,
@eventHeader short: Short,
@eventHeader string: String,
@eventHeader timestamp: Timestamp,
}
structure MessageWithHeaderAndPayload {
@eventHeader header: String,
@eventPayload payload: Blob,
}
structure MessageWithNoHeaderPayloadTraits {
someInt: Integer,
someString: String,
}
@streaming
union TestStream {
MessageWithBlob: MessageWithBlob,
MessageWithString: MessageWithString,
MessageWithStruct: MessageWithStruct,
MessageWithUnion: MessageWithUnion,
MessageWithHeaders: MessageWithHeaders,
MessageWithHeaderAndPayload: MessageWithHeaderAndPayload,
MessageWithNoHeaderPayloadTraits: MessageWithNoHeaderPayloadTraits,
SomeError: SomeError,
}
structure TestStreamInputOutput { @required value: TestStream }
operation TestStreamOp {
input: TestStreamInputOutput,
output: TestStreamInputOutput,
errors: [SomeError],
}
$extraServiceAnnotations
@$protocolName
service TestService { version: "123", operations: [TestStreamOp] }
"""
object EventStreamTestModels {
fun restJson1(): Model = fillInBaseModel("restJson1").asSmithyModel()
fun restXml(): Model = fillInBaseModel("restXml").asSmithyModel()
fun awsJson11(): Model = fillInBaseModel("awsJson1_1").asSmithyModel()
fun awsQuery(): Model = fillInBaseModel("awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
fun ec2Query(): Model = fillInBaseModel("ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
data class TestCase(
val protocolShapeId: String,
val model: Model,
val contentType: String,
val validTestStruct: String,
val validMessageWithNoHeaderPayloadTraits: String,
val validTestUnion: String,
val validSomeError: String,
val protocolBuilder: (ProtocolConfig) -> Protocol,
) {
override fun toString(): String = protocolShapeId
}
class ModelArgumentsProvider : ArgumentsProvider {
override fun provideArguments(context: ExtensionContext?): Stream<out Arguments> =
Stream.of(
Arguments.of(
TestCase(
protocolShapeId = "aws.protocols#restJson1",
model = restJson1(),
contentType = "application/json",
validTestStruct = """{"someString":"hello","someInt":5}""",
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
validTestUnion = """{"Foo":"hello"}""",
validSomeError = """{"Message":"some error"}""",
) { RestJson(it) }
),
Arguments.of(
TestCase(
protocolShapeId = "aws.protocols#awsJson1_1",
model = awsJson11(),
contentType = "application/x-amz-json-1.1",
validTestStruct = """{"someString":"hello","someInt":5}""",
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
validTestUnion = """{"Foo":"hello"}""",
validSomeError = """{"Message":"some error"}""",
) { AwsJson(it, AwsJsonVersion.Json11) }
),
Arguments.of(
TestCase(
protocolShapeId = "aws.protocols#restXml",
model = restXml(),
contentType = "text/xml",
validTestStruct = """
<TestStruct>
<someString>hello</someString>
<someInt>5</someInt>
</TestStruct>
""".trimIndent(),
validMessageWithNoHeaderPayloadTraits = """
<MessageWithNoHeaderPayloadTraits>
<someString>hello</someString>
<someInt>5</someInt>
</MessageWithNoHeaderPayloadTraits>
""".trimIndent(),
validTestUnion = "<TestUnion><Foo>hello</Foo></TestUnion>",
validSomeError = """
<ErrorResponse>
<Error>
<Type>SomeError</Type>
<Code>SomeError</Code>
<Message>some error</Message>
</Error>
</ErrorResponse>
""".trimIndent()
) { RestXml(it) }
),
Arguments.of(
TestCase(
protocolShapeId = "aws.protocols#awsQuery",
model = awsQuery(),
contentType = "application/x-www-form-urlencoded",
validTestStruct = """
<TestStruct>
<someString>hello</someString>
<someInt>5</someInt>
</TestStruct>
""".trimIndent(),
validMessageWithNoHeaderPayloadTraits = """
<MessageWithNoHeaderPayloadTraits>
<someString>hello</someString>
<someInt>5</someInt>
</MessageWithNoHeaderPayloadTraits>
""".trimIndent(),
validTestUnion = "<TestUnion><Foo>hello</Foo></TestUnion>",
validSomeError = """
<ErrorResponse>
<Error>
<Type>SomeError</Type>
<Code>SomeError</Code>
<Message>some error</Message>
</Error>
</ErrorResponse>
""".trimIndent()
) { AwsQueryProtocol(it) }
),
Arguments.of(
TestCase(
protocolShapeId = "aws.protocols#ec2Query",
model = ec2Query(),
contentType = "application/x-www-form-urlencoded",
validTestStruct = """
<TestStruct>
<someString>hello</someString>
<someInt>5</someInt>
</TestStruct>
""".trimIndent(),
validMessageWithNoHeaderPayloadTraits = """
<MessageWithNoHeaderPayloadTraits>
<someString>hello</someString>
<someInt>5</someInt>
</MessageWithNoHeaderPayloadTraits>
""".trimIndent(),
validTestUnion = "<TestUnion><Foo>hello</Foo></TestUnion>",
validSomeError = """
<Response>
<Errors>
<Error>
<Type>SomeError</Type>
<Code>SomeError</Code>
<Message>some error</Message>
</Error>
</Error>
</Response>
""".trimIndent()
) { Ec2QueryProtocol(it) }
),
)
}
}
data class TestEventStreamProject(
val model: Model,
val serviceShape: ServiceShape,
val operationShape: OperationShape,
val streamShape: UnionShape,
val symbolProvider: RustSymbolProvider,
val project: TestWriterDelegator,
)
object EventStreamTestTools {
fun generateTestProject(model: Model): TestEventStreamProject {
val model = EventStreamNormalizer.transform(OperationNormalizer.transform(model))
val serviceShape = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape
val operationShape = model.expectShape(ShapeId.from("test#TestStreamOp")) as OperationShape
val unionShape = model.expectShape(ShapeId.from("test#TestStream")) as UnionShape
val symbolProvider = testSymbolProvider(model)
val project = TestWorkspace.testProject(symbolProvider)
project.withModule(RustModule.default("error", public = true)) {
CombinedErrorGenerator(model, symbolProvider, operationShape).render(it)
for (shape in model.shapes().filter { shape -> shape.isStructureShape && shape.hasTrait<ErrorTrait>() }) {
StructureGenerator(model, symbolProvider, it, shape as StructureShape).render()
val builderGen = BuilderGenerator(model, symbolProvider, shape)
builderGen.render(it)
it.implBlock(shape, symbolProvider) {
builderGen.renderConvenienceMethod(this)
}
}
}
project.withModule(RustModule.default("model", public = true)) {
val inputOutput = model.lookup<StructureShape>("test#TestStreamInputOutput")
recursivelyGenerateModels(model, symbolProvider, inputOutput, it)
}
project.withModule(RustModule.default("output", public = true)) {
operationShape.outputShape(model).renderWithModelBuilder(model, symbolProvider, it)
}
println("file:///${project.baseDir}/src/error.rs")
println("file:///${project.baseDir}/src/event_stream.rs")
println("file:///${project.baseDir}/src/event_stream_serde.rs")
println("file:///${project.baseDir}/src/lib.rs")
println("file:///${project.baseDir}/src/model.rs")
return TestEventStreamProject(model, serviceShape, operationShape, unionShape, symbolProvider, project)
}
private fun recursivelyGenerateModels(
model: Model,
symbolProvider: RustSymbolProvider,
shape: Shape,
writer: RustWriter
) {
for (member in shape.members()) {
val target = model.expectShape(member.target)
if (target is StructureShape || target is UnionShape) {
if (target is StructureShape) {
target.renderWithModelBuilder(model, symbolProvider, writer)
} else if (target is UnionShape) {
UnionGenerator(model, symbolProvider, writer, target).render()
}
recursivelyGenerateModels(model, symbolProvider, target, writer)
}
}
}
}

View File

@ -42,7 +42,7 @@ class AwsQueryParserGeneratorTest {
@Test
fun `it modifies operation parsing to include Response and Result tags`() {
val model = RecursiveShapeBoxer.transform(OperationNormalizer(baseModel).transformModel())
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val symbolProvider = testSymbolProvider(model)
val parserGenerator = AwsQueryParserGenerator(
testProtocolConfig(model),

View File

@ -42,7 +42,7 @@ class Ec2QueryParserGeneratorTest {
@Test
fun `it modifies operation parsing to include Response and Result tags`() {
val model = RecursiveShapeBoxer.transform(OperationNormalizer(baseModel).transformModel())
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val symbolProvider = testSymbolProvider(model)
val parserGenerator = Ec2QueryParserGenerator(
testProtocolConfig(model),

View File

@ -0,0 +1,245 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
package software.amazon.smithy.rust.codegen.smithy.protocols.parse
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ArgumentsSource
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig
import software.amazon.smithy.rust.codegen.smithy.protocols.EventStreamTestModels
import software.amazon.smithy.rust.codegen.smithy.protocols.EventStreamTestTools
import software.amazon.smithy.rust.codegen.testutil.TestRuntimeConfig
import software.amazon.smithy.rust.codegen.testutil.compileAndTest
import software.amazon.smithy.rust.codegen.testutil.unitTest
class EventStreamUnmarshallerGeneratorTest {
@ParameterizedTest
@ArgumentsSource(EventStreamTestModels.ModelArgumentsProvider::class)
fun test(testCase: EventStreamTestModels.TestCase) {
val test = EventStreamTestTools.generateTestProject(testCase.model)
val protocolConfig = ProtocolConfig(
test.model,
test.symbolProvider,
TestRuntimeConfig,
test.serviceShape,
ShapeId.from(testCase.protocolShapeId),
"test"
)
val protocol = testCase.protocolBuilder(protocolConfig)
val generator = EventStreamUnmarshallerGenerator(
protocol,
test.model,
TestRuntimeConfig,
test.symbolProvider,
test.operationShape,
test.streamShape
)
test.project.lib { writer ->
// TODO(EventStream): Add test for bad content type
// TODO(EventStream): Add test for generic error parsing
writer.rust(
"""
use smithy_eventstream::frame::{Header, HeaderValue, Message, UnmarshallMessage, UnmarshalledMessage};
use smithy_types::{Blob, Instant};
use crate::error::*;
use crate::model::*;
fn msg(
message_type: &'static str,
event_type: &'static str,
content_type: &'static str,
payload: &'static [u8],
) -> Message {
let message = Message::new(payload)
.add_header(Header::new(":message-type", HeaderValue::String(message_type.into())))
.add_header(Header::new(":content-type", HeaderValue::String(content_type.into())));
if message_type == "event" {
message.add_header(Header::new(":event-type", HeaderValue::String(event_type.into())))
} else {
message.add_header(Header::new(":exception-type", HeaderValue::String(event_type.into())))
}
}
fn expect_event<T: std::fmt::Debug, E: std::fmt::Debug>(unmarshalled: UnmarshalledMessage<T, E>) -> T {
match unmarshalled {
UnmarshalledMessage::Event(event) => event,
_ => panic!("expected event, got: {:?}", unmarshalled),
}
}
fn expect_error<T: std::fmt::Debug, E: std::fmt::Debug>(unmarshalled: UnmarshalledMessage<T, E>) -> E {
match unmarshalled {
UnmarshalledMessage::Error(error) => error,
_ => panic!("expected error, got: {:?}", unmarshalled),
}
}
"""
)
writer.unitTest(
"""
let message = msg("event", "MessageWithBlob", "application/octet-stream", b"hello, world!");
let result = ${writer.format(generator.render())}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithBlob(
MessageWithBlob::builder().data(Blob::new(&b"hello, world!"[..])).build()
),
expect_event(result.unwrap())
);
""",
"message_with_blob",
)
writer.unitTest(
"""
let message = msg("event", "MessageWithString", "application/octet-stream", b"hello, world!");
let result = ${writer.format(generator.render())}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithString(MessageWithString::builder().data("hello, world!").build()),
expect_event(result.unwrap())
);
""",
"message_with_string",
)
writer.unitTest(
"""
let message = msg(
"event",
"MessageWithStruct",
"${testCase.contentType}",
br#"${testCase.validTestStruct}"#
);
let result = ${writer.format(generator.render())}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithStruct(MessageWithStruct::builder().some_struct(
TestStruct::builder()
.some_string("hello")
.some_int(5)
.build()
).build()),
expect_event(result.unwrap())
);
""",
"message_with_struct",
)
writer.unitTest(
"""
let message = msg(
"event",
"MessageWithUnion",
"${testCase.contentType}",
br#"${testCase.validTestUnion}"#
);
let result = ${writer.format(generator.render())}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithUnion(MessageWithUnion::builder().some_union(
TestUnion::Foo("hello".into())
).build()),
expect_event(result.unwrap())
);
""",
"message_with_union",
)
writer.unitTest(
"""
let message = msg("event", "MessageWithHeaders", "application/octet-stream", b"")
.add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into())))
.add_header(Header::new("boolean", HeaderValue::Bool(true)))
.add_header(Header::new("byte", HeaderValue::Byte(55i8)))
.add_header(Header::new("int", HeaderValue::Int32(100_000i32)))
.add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64)))
.add_header(Header::new("short", HeaderValue::Int16(16_000i16)))
.add_header(Header::new("string", HeaderValue::String("test".into())))
.add_header(Header::new("timestamp", HeaderValue::Timestamp(Instant::from_epoch_seconds(5))));
let result = ${writer.format(generator.render())}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithHeaders(MessageWithHeaders::builder()
.blob(Blob::new(&b"test"[..]))
.boolean(true)
.byte(55i8)
.int(100_000i32)
.long(9_000_000_000i64)
.short(16_000i16)
.string("test")
.timestamp(Instant::from_epoch_seconds(5))
.build()
),
expect_event(result.unwrap())
);
""",
"message_with_headers",
)
writer.unitTest(
"""
let message = msg("event", "MessageWithHeaderAndPayload", "application/octet-stream", b"payload")
.add_header(Header::new("header", HeaderValue::String("header".into())));
let result = ${writer.format(generator.render())}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder()
.header("header")
.payload(Blob::new(&b"payload"[..]))
.build()
),
expect_event(result.unwrap())
);
""",
"message_with_header_and_payload",
)
writer.unitTest(
"""
let message = msg(
"event",
"MessageWithNoHeaderPayloadTraits",
"${testCase.contentType}",
br#"${testCase.validMessageWithNoHeaderPayloadTraits}"#
);
let result = ${writer.format(generator.render())}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder()
.some_int(5)
.some_string("hello")
.build()
),
expect_event(result.unwrap())
);
""",
"message_with_no_header_payload_traits",
)
writer.unitTest(
"""
let message = msg(
"exception",
"SomeError",
"${testCase.contentType}",
br#"${testCase.validSomeError}"#
);
let result = ${writer.format(generator.render())}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
match expect_error(result.unwrap()).kind {
TestStreamOpErrorKind::SomeError(err) => assert_eq!(Some("some error"), err.message()),
kind => panic!("expected SomeError, but got {:?}", kind),
}
""",
"some_error",
)
}
test.project.compileAndTest()
}
}

View File

@ -106,7 +106,7 @@ class JsonParserGeneratorTest {
@Test
fun `generates valid deserializers`() {
val model = RecursiveShapeBoxer.transform(OperationNormalizer(baseModel).transformModel())
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val symbolProvider = testSymbolProvider(model)
val parserGenerator = JsonParserGenerator(
testProtocolConfig(model),

View File

@ -90,7 +90,7 @@ internal class XmlBindingTraitParserGeneratorTest {
@Test
fun `generates valid parsers`() {
val model = RecursiveShapeBoxer.transform(OperationNormalizer(baseModel).transformModel())
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val symbolProvider = testSymbolProvider(model)
val parserGenerator = XmlBindingTraitParserGenerator(
testProtocolConfig(model),

View File

@ -81,7 +81,7 @@ class AwsQuerySerializerGeneratorTest {
@Test
fun `generates valid serializers`() {
val model = RecursiveShapeBoxer.transform(OperationNormalizer(baseModel).transformModel())
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val symbolProvider = testSymbolProvider(model)
val parserGenerator = AwsQuerySerializerGenerator(testProtocolConfig(model))
val operationGenerator = parserGenerator.operationSerializer(model.lookup("test#Op"))

View File

@ -80,7 +80,7 @@ class Ec2QuerySerializerGeneratorTest {
@Test
fun `generates valid serializers`() {
val model = RecursiveShapeBoxer.transform(OperationNormalizer(baseModel).transformModel())
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val symbolProvider = testSymbolProvider(model)
val parserGenerator = Ec2QuerySerializerGenerator(testProtocolConfig(model))
val operationGenerator = parserGenerator.operationSerializer(model.lookup("test#Op"))

View File

@ -98,7 +98,7 @@ class JsonSerializerGeneratorTest {
@Test
fun `generates valid serializers`() {
val model = RecursiveShapeBoxer.transform(OperationNormalizer(baseModel).transformModel())
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val symbolProvider = testSymbolProvider(model)
val parserSerializer = JsonSerializerGenerator(
testProtocolConfig(model),

View File

@ -103,7 +103,7 @@ internal class XmlBindingTraitSerializerGeneratorTest {
@Test
fun `generates valid serializers`() {
val model = RecursiveShapeBoxer.transform(OperationNormalizer(baseModel).transformModel())
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val symbolProvider = testSymbolProvider(model)
val parserGenerator = XmlBindingTraitSerializerGenerator(
testProtocolConfig(model),
@ -124,7 +124,7 @@ internal class XmlBindingTraitSerializerGeneratorTest {
.build()
).build().unwrap();
let serialized = ${writer.format(operationParser)}(&inp.payload.unwrap()).unwrap();
let output = std::str::from_utf8(serialized.bytes().unwrap()).unwrap();
let output = std::str::from_utf8(&serialized).unwrap();
assert_eq!(output, "<Top extra=\"45\"><field>hello!</field><recursive extra=\"55\"></recursive></Top>");
"""
)

View File

@ -0,0 +1,63 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
package software.amazon.smithy.rust.codegen.smithy.transformers
import io.kotest.matchers.shouldBe
import org.junit.jupiter.api.Test
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticEventStreamUnionTrait
import software.amazon.smithy.rust.codegen.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.util.expectTrait
import software.amazon.smithy.rust.codegen.util.hasTrait
class EventStreamNormalizerTest {
@Test
fun `it should leave normal unions alone`() {
val transformed = EventStreamNormalizer.transform(
"""
namespace test
union SomeNormalUnion {
Foo: String,
Bar: Long,
}
""".asSmithyModel()
)
val shape = transformed.expectShape(ShapeId.from("test#SomeNormalUnion"), UnionShape::class.java)
shape.hasTrait<SyntheticEventStreamUnionTrait>() shouldBe false
shape.memberNames shouldBe listOf("Foo", "Bar")
}
@Test
fun `it should transform event stream unions`() {
val transformed = EventStreamNormalizer.transform(
"""
namespace test
structure SomeMember {
}
@error("client")
structure SomeError {
}
@streaming
union SomeEventStream {
SomeMember: SomeMember,
SomeError: SomeError,
}
""".asSmithyModel()
)
val shape = transformed.expectShape(ShapeId.from("test#SomeEventStream"), UnionShape::class.java)
shape.hasTrait<SyntheticEventStreamUnionTrait>() shouldBe true
shape.memberNames shouldBe listOf("SomeMember")
val trait = shape.expectTrait<SyntheticEventStreamUnionTrait>()
trait.errorMembers.map { it.memberName } shouldBe listOf("SomeError")
}
}

View File

@ -26,8 +26,7 @@ internal class OperationNormalizerTest {
""".asSmithyModel()
val operationId = ShapeId.from("smithy.test#Empty")
model.expectShape(operationId, OperationShape::class.java).input.isPresent shouldBe false
val sut = OperationNormalizer(model)
val modified = sut.transformModel()
val modified = OperationNormalizer.transform(model)
val operation = modified.expectShape(operationId, OperationShape::class.java)
operation.input.isPresent shouldBe true
operation.input.get().name shouldBe "EmptyInput"
@ -55,8 +54,7 @@ internal class OperationNormalizerTest {
""".asSmithyModel()
val operationId = ShapeId.from("smithy.test#MyOp")
model.expectShape(operationId, OperationShape::class.java).input.isPresent shouldBe true
val sut = OperationNormalizer(model)
val modified = sut.transformModel()
val modified = OperationNormalizer.transform(model)
val operation = modified.expectShape(operationId, OperationShape::class.java)
operation.input.isPresent shouldBe true
val inputId = operation.input.get()
@ -79,8 +77,7 @@ internal class OperationNormalizerTest {
""".asSmithyModel()
val operationId = ShapeId.from("smithy.test#MyOp")
model.expectShape(operationId, OperationShape::class.java).output.isPresent shouldBe true
val sut = OperationNormalizer(model)
val modified = sut.transformModel()
val modified = OperationNormalizer.transform(model)
val operation = modified.expectShape(operationId, OperationShape::class.java)
operation.output.isPresent shouldBe true
val outputId = operation.output.get()

View File

@ -11,8 +11,8 @@ are to allow this crate to be compilable and testable in isolation, no client co
[dependencies]
"bytes" = "1"
"http" = "0.2.1"
"smithy-types" = { version = "0.0.1", path = "../smithy-types" }
"smithy-http" = { version = "0.0.1", path = "../smithy-http" }
"smithy-types" = { path = "../smithy-types" }
"smithy-http" = { path = "../smithy-http" }
"smithy-json" = { path = "../smithy-json" }
"smithy-query" = { path = "../smithy-query" }
"smithy-xml" = { path = "../smithy-xml" }

View File

@ -23,6 +23,8 @@ pub enum Error {
PayloadTooLong,
PreludeChecksumMismatch(u32, u32),
TimestampValueTooLarge(Instant),
Marshalling(String),
Unmarshalling(String),
}
impl StdError for Error {}
@ -56,6 +58,8 @@ impl fmt::Display for Error {
"timestamp value {:?} is too large to fit into an i64",
time
),
Marshalling(error) => write!(f, "failed to marshall message: {}", error),
Unmarshalling(error) => write!(f, "failed to unmarshall message: {}", error),
}
}
}

View File

@ -12,6 +12,7 @@ use crate::str_bytes::StrBytes;
use bytes::{Buf, BufMut, Bytes};
use std::convert::{TryFrom, TryInto};
use std::error::Error as StdError;
use std::fmt;
use std::mem::size_of;
const PRELUDE_LENGTH_BYTES: u32 = 3 * size_of::<u32>() as u32;
@ -23,24 +24,36 @@ const MIN_HEADER_LEN: usize = 2;
pub type SignMessageError = Box<dyn StdError + Send + Sync + 'static>;
/// Signs an Event Stream message.
pub trait SignMessage {
pub trait SignMessage: fmt::Debug {
fn sign(&mut self, message: Message) -> Result<Message, SignMessageError>;
}
/// Converts a Smithy modeled Event Stream type into a [`Message`](Message).
pub trait MarshallMessage {
pub trait MarshallMessage: fmt::Debug {
/// Smithy modeled input type to convert from.
type Input;
fn marshall(&self, input: Self::Input) -> Result<Message, Error>;
}
/// A successfully unmarshalled message that is either an `Event` or an `Error`.
#[derive(Debug)]
pub enum UnmarshalledMessage<T, E> {
Event(T),
Error(E),
}
/// Converts an Event Stream [`Message`](Message) into a Smithy modeled type.
pub trait UnmarshallMessage {
pub trait UnmarshallMessage: fmt::Debug {
/// Smithy modeled type to convert into.
type Output;
/// Smithy modeled error to convert into.
type Error;
fn unmarshall(&self, message: Message) -> Result<Self::Output, Error>;
fn unmarshall(
&self,
message: &Message,
) -> Result<UnmarshalledMessage<Self::Output, Self::Error>, Error>;
}
mod value {
@ -68,7 +81,7 @@ mod value {
#[derive(Clone, Debug, PartialEq)]
pub enum HeaderValue {
Bool(bool),
Byte(u8),
Byte(i8),
Int16(i16),
Int32(i32),
Int64(i64),
@ -78,6 +91,71 @@ mod value {
Uuid(u128),
}
impl HeaderValue {
pub fn as_bool(&self) -> Result<bool, &Self> {
match self {
HeaderValue::Bool(value) => Ok(*value),
_ => Err(self),
}
}
pub fn as_byte(&self) -> Result<i8, &Self> {
match self {
HeaderValue::Byte(value) => Ok(*value),
_ => Err(self),
}
}
pub fn as_int16(&self) -> Result<i16, &Self> {
match self {
HeaderValue::Int16(value) => Ok(*value),
_ => Err(self),
}
}
pub fn as_int32(&self) -> Result<i32, &Self> {
match self {
HeaderValue::Int32(value) => Ok(*value),
_ => Err(self),
}
}
pub fn as_int64(&self) -> Result<i64, &Self> {
match self {
HeaderValue::Int64(value) => Ok(*value),
_ => Err(self),
}
}
pub fn as_byte_array(&self) -> Result<&Bytes, &Self> {
match self {
HeaderValue::ByteArray(value) => Ok(value),
_ => Err(self),
}
}
pub fn as_string(&self) -> Result<&StrBytes, &Self> {
match self {
HeaderValue::String(value) => Ok(value),
_ => Err(self),
}
}
pub fn as_timestamp(&self) -> Result<Instant, &Self> {
match self {
HeaderValue::Timestamp(value) => Ok(*value),
_ => Err(self),
}
}
pub fn as_uuid(&self) -> Result<u128, &Self> {
match self {
HeaderValue::Uuid(value) => Ok(*value),
_ => Err(self),
}
}
}
macro_rules! read_value {
($buf:ident, $typ:ident, $size_typ:ident, $read_fn:ident) => {
if $buf.remaining() >= size_of::<$size_typ>() {
@ -94,7 +172,7 @@ mod value {
match value_type {
TYPE_TRUE => Ok(HeaderValue::Bool(true)),
TYPE_FALSE => Ok(HeaderValue::Bool(false)),
TYPE_BYTE => read_value!(buffer, Byte, u8, get_u8),
TYPE_BYTE => read_value!(buffer, Byte, i8, get_i8),
TYPE_INT16 => read_value!(buffer, Int16, i16, get_i16),
TYPE_INT32 => read_value!(buffer, Int32, i32, get_i32),
TYPE_INT64 => read_value!(buffer, Int64, i64, get_i64),
@ -137,7 +215,7 @@ mod value {
Bool(val) => buffer.put_u8(if *val { TYPE_TRUE } else { TYPE_FALSE }),
Byte(val) => {
buffer.put_u8(TYPE_BYTE);
buffer.put_u8(*val);
buffer.put_i8(*val);
}
Int16(val) => {
buffer.put_u8(TYPE_INT16);
@ -184,7 +262,7 @@ mod value {
Ok(match value_type {
TYPE_TRUE => HeaderValue::Bool(true),
TYPE_FALSE => HeaderValue::Bool(false),
TYPE_BYTE => HeaderValue::Byte(u8::arbitrary(unstruct)?),
TYPE_BYTE => HeaderValue::Byte(i8::arbitrary(unstruct)?),
TYPE_INT16 => HeaderValue::Int16(i16::arbitrary(unstruct)?),
TYPE_INT32 => HeaderValue::Int32(i32::arbitrary(unstruct)?),
TYPE_INT64 => HeaderValue::Int64(i64::arbitrary(unstruct)?),
@ -289,6 +367,14 @@ impl Message {
}
}
/// Creates a message with the given `headers` and `payload`.
pub fn new_from_parts(headers: Vec<Header>, payload: impl Into<Bytes>) -> Self {
Self {
headers,
payload: payload.into(),
}
}
/// Adds a header to the message.
pub fn add_header(mut self, header: Header) -> Self {
self.headers.push(header);
@ -609,7 +695,7 @@ pub enum DecodedFrame {
/// Streaming decoder for decoding a [`Message`] from a stream.
#[non_exhaustive]
#[derive(Default)]
#[derive(Default, Debug)]
pub struct MessageFrameDecoder {
prelude: [u8; PRELUDE_LENGTH_BYTES_USIZE],
prelude_read: bool,

View File

@ -8,4 +8,5 @@
mod buf;
pub mod error;
pub mod frame;
pub mod smithy;
pub mod str_bytes;

View File

@ -0,0 +1,179 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
use crate::error::Error;
use crate::frame::{Header, HeaderValue, Message};
use crate::str_bytes::StrBytes;
use smithy_types::{Blob, Instant};
macro_rules! expect_shape_fn {
(fn $fn_name:ident[$val_typ:ident] -> $result_typ:ident { $val_name:ident -> $val_expr:expr }) => {
pub fn $fn_name(header: &Header) -> Result<$result_typ, Error> {
match header.value() {
HeaderValue::$val_typ($val_name) => Ok($val_expr),
_ => Err(Error::Unmarshalling(format!(
"expected '{}' header value to be {}",
header.name().as_str(),
stringify!($val_typ)
))),
}
}
};
}
expect_shape_fn!(fn expect_bool[Bool] -> bool { value -> *value });
expect_shape_fn!(fn expect_byte[Byte] -> i8 { value -> *value });
expect_shape_fn!(fn expect_int16[Int16] -> i16 { value -> *value });
expect_shape_fn!(fn expect_int32[Int32] -> i32 { value -> *value });
expect_shape_fn!(fn expect_int64[Int64] -> i64 { value -> *value });
expect_shape_fn!(fn expect_byte_array[ByteArray] -> Blob { bytes -> Blob::new(bytes.as_ref()) });
expect_shape_fn!(fn expect_string[String] -> String { value -> value.as_str().into() });
expect_shape_fn!(fn expect_timestamp[Timestamp] -> Instant { value -> *value });
pub struct ResponseHeaders<'a> {
pub content_type: &'a StrBytes,
pub message_type: &'a StrBytes,
pub smithy_type: &'a StrBytes,
}
fn expect_header_str_value<'a>(
header: Option<&'a Header>,
name: &str,
) -> Result<&'a StrBytes, Error> {
match header {
Some(header) => Ok(header.value().as_string().map_err(|value| {
Error::Unmarshalling(format!(
"expected response {} header to be string, received {:?}",
name, value
))
})?),
None => Err(Error::Unmarshalling(format!(
"expected response to include {} header, but it was missing",
name
))),
}
}
pub fn parse_response_headers(message: &Message) -> Result<ResponseHeaders, Error> {
let (mut content_type, mut message_type, mut event_type, mut exception_type) =
(None, None, None, None);
for header in message.headers() {
match header.name().as_str() {
":content-type" => content_type = Some(header),
":message-type" => message_type = Some(header),
":event-type" => event_type = Some(header),
":exception-type" => exception_type = Some(header),
_ => {}
}
}
let message_type = expect_header_str_value(message_type, ":message-type")?;
Ok(ResponseHeaders {
content_type: expect_header_str_value(content_type, ":content-type")?,
message_type,
smithy_type: if message_type.as_str() == "event" {
expect_header_str_value(event_type, ":event-type")?
} else if message_type.as_str() == "exception" {
expect_header_str_value(exception_type, ":exception-type")?
} else {
return Err(Error::Unmarshalling(format!(
"unrecognized `:message-type`: {}",
message_type.as_str()
)));
},
})
}
#[cfg(test)]
mod tests {
use super::parse_response_headers;
use crate::frame::{Header, HeaderValue, Message};
#[test]
fn normal_message() {
let message = Message::new(&b"test"[..])
.add_header(Header::new(
":event-type",
HeaderValue::String("Foo".into()),
))
.add_header(Header::new(
":content-type",
HeaderValue::String("application/json".into()),
))
.add_header(Header::new(
":message-type",
HeaderValue::String("event".into()),
));
let parsed = parse_response_headers(&message).unwrap();
assert_eq!("Foo", parsed.smithy_type.as_str());
assert_eq!("application/json", parsed.content_type.as_str());
assert_eq!("event", parsed.message_type.as_str());
}
#[test]
fn error_message() {
let message = Message::new(&b"test"[..])
.add_header(Header::new(
":exception-type",
HeaderValue::String("BadRequestException".into()),
))
.add_header(Header::new(
":content-type",
HeaderValue::String("application/json".into()),
))
.add_header(Header::new(
":message-type",
HeaderValue::String("exception".into()),
));
let parsed = parse_response_headers(&message).unwrap();
assert_eq!("BadRequestException", parsed.smithy_type.as_str());
assert_eq!("application/json", parsed.content_type.as_str());
assert_eq!("exception", parsed.message_type.as_str());
}
#[test]
fn missing_exception_type() {
let message = Message::new(&b"test"[..])
.add_header(Header::new(
":content-type",
HeaderValue::String("application/json".into()),
))
.add_header(Header::new(
":message-type",
HeaderValue::String("exception".into()),
));
let error = parse_response_headers(&message).err().unwrap().to_string();
assert_eq!("failed to unmarshall message: expected response to include :exception-type header, but it was missing", error);
}
#[test]
fn missing_event_type() {
let message = Message::new(&b"test"[..])
.add_header(Header::new(
":content-type",
HeaderValue::String("application/json".into()),
))
.add_header(Header::new(
":message-type",
HeaderValue::String("event".into()),
));
let error = parse_response_headers(&message).err().unwrap().to_string();
assert_eq!("failed to unmarshall message: expected response to include :event-type header, but it was missing", error);
}
#[test]
fn missing_content_type() {
let message = Message::new(&b"test"[..])
.add_header(Header::new(
":event-type",
HeaderValue::String("Foo".into()),
))
.add_header(Header::new(
":message-type",
HeaderValue::String("event".into()),
));
let error = parse_response_headers(&message).err().unwrap().to_string();
assert_eq!("failed to unmarshall message: expected response to include :content-type header, but it was missing", error);
}
}

View File

@ -8,7 +8,7 @@ license = "Apache-2.0"
[features]
bytestream-util = ["tokio/fs", "tokio-util/io"]
event-stream = ["smithy-eventstream"]
default = ["bytestream-util", "event-stream"]
default = ["bytestream-util"]
[dependencies]
smithy-types = { path = "../smithy-types" }

View File

@ -13,13 +13,50 @@ use futures_core::Stream;
use hyper::body::HttpBody;
use pin_project::pin_project;
use smithy_eventstream::frame::{
DecodedFrame, MarshallMessage, MessageFrameDecoder, SignMessage, UnmarshallMessage,
DecodedFrame, MarshallMessage, Message, MessageFrameDecoder, SignMessage, UnmarshallMessage,
UnmarshalledMessage,
};
use std::error::Error as StdError;
use std::fmt;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
pub type BoxError = Box<dyn StdError + Send + Sync + 'static>;
/// Input type for Event Streams.
pub struct EventStreamInput<T> {
input_stream: Pin<Box<dyn Stream<Item = Result<T, BoxError>> + Send>>,
}
impl<T> fmt::Debug for EventStreamInput<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "EventStreamInput(Box<dyn Stream>)")
}
}
impl<T> EventStreamInput<T> {
#[doc(hidden)]
pub fn into_body_stream<E: StdError + Send + Sync + 'static>(
self,
marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
signer: impl SignMessage + Send + Sync + 'static,
) -> MessageStreamAdapter<T, E> {
MessageStreamAdapter::new(marshaller, signer, self.input_stream)
}
}
impl<T, S> From<S> for EventStreamInput<T>
where
S: Stream<Item = Result<T, BoxError>> + Send + 'static,
{
fn from(stream: S) -> Self {
EventStreamInput {
input_stream: Box::pin(stream),
}
}
}
/// Adapts a `Stream<SmithyMessageType>` to a signed `Stream<Bytes>` by using the provided
/// message marshaller and signer implementations.
///
@ -30,24 +67,32 @@ pub struct MessageStreamAdapter<T, E> {
marshaller: Box<dyn MarshallMessage<Input = T> + Send + Sync>,
signer: Box<dyn SignMessage + Send + Sync>,
#[pin]
stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>>,
stream: Pin<Box<dyn Stream<Item = Result<T, BoxError>> + Send>>,
_phantom: PhantomData<E>,
}
impl<T, E: StdError + Send + Sync + 'static> MessageStreamAdapter<T, E> {
impl<T, E> MessageStreamAdapter<T, E>
where
E: StdError + Send + Sync + 'static,
{
pub fn new(
marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
signer: impl SignMessage + Send + Sync + 'static,
stream: impl Stream<Item = Result<T, E>> + Send + Sync + 'static,
stream: Pin<Box<dyn Stream<Item = Result<T, BoxError>> + Send>>,
) -> Self {
MessageStreamAdapter {
marshaller: Box::new(marshaller),
signer: Box::new(signer),
stream: Box::pin(stream),
stream,
_phantom: Default::default(),
}
}
}
impl<T, E: StdError + Send + Sync + 'static> Stream for MessageStreamAdapter<T, E> {
impl<T, E> Stream for MessageStreamAdapter<T, E>
where
E: StdError + Send + Sync + 'static,
{
type Item = Result<Bytes, SdkError<E>>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
@ -56,7 +101,7 @@ impl<T, E: StdError + Send + Sync + 'static> Stream for MessageStreamAdapter<T,
Poll::Ready(message_option) => {
if let Some(message_result) = message_option {
let message_result =
message_result.map_err(|err| SdkError::ConstructionFailure(Box::new(err)));
message_result.map_err(|err| SdkError::ConstructionFailure(err));
let message = this
.marshaller
.marshall(message_result?)
@ -80,17 +125,21 @@ impl<T, E: StdError + Send + Sync + 'static> Stream for MessageStreamAdapter<T,
}
/// Receives Smithy-modeled messages out of an Event Stream.
pub struct Receiver<T, E: StdError + Send + Sync> {
unmarshaller: Box<dyn UnmarshallMessage<Output = T>>,
#[derive(Debug)]
pub struct Receiver<T, E> {
unmarshaller: Box<dyn UnmarshallMessage<Output = T, Error = E>>,
decoder: MessageFrameDecoder,
buffer: SegmentedBuf<Bytes>,
body: SdkBody,
_phantom: PhantomData<E>,
}
impl<T, E: StdError + Send + Sync> Receiver<T, E> {
impl<T, E> Receiver<T, E> {
/// Creates a new `Receiver` with the given message unmarshaller and SDK body.
pub fn new(unmarshaller: impl UnmarshallMessage<Output = T> + 'static, body: SdkBody) -> Self {
pub fn new(
unmarshaller: impl UnmarshallMessage<Output = T, Error = E> + 'static,
body: SdkBody,
) -> Self {
Receiver {
unmarshaller: Box::new(unmarshaller),
decoder: MessageFrameDecoder::new(),
@ -104,7 +153,7 @@ impl<T, E: StdError + Send + Sync> Receiver<T, E> {
/// it returns an `Ok(None)`. If there is a transport layer error, it will return
/// `Err(SdkError::DispatchFailure)`. Service-modeled errors will be a part of the returned
/// messages.
pub async fn recv(&mut self) -> Result<Option<T>, SdkError<E>> {
pub async fn recv(&mut self) -> Result<Option<T>, SdkError<E, Message>> {
let next_chunk = self
.body
.data()
@ -119,11 +168,16 @@ impl<T, E: StdError + Send + Sync> Receiver<T, E> {
.decode_frame(&mut self.buffer)
.map_err(|err| SdkError::DispatchFailure(Box::new(err)))?
{
return Ok(Some(
self.unmarshaller
.unmarshall(message)
.map_err(|err| SdkError::DispatchFailure(Box::new(err)))?,
));
return match self
.unmarshaller
.unmarshall(&message)
.map_err(|err| SdkError::DispatchFailure(Box::new(err)))?
{
UnmarshalledMessage::Event(event) => Ok(Some(event)),
UnmarshalledMessage::Error(err) => {
Err(SdkError::ServiceError { err, raw: message })
}
};
}
}
Ok(None)
@ -134,7 +188,7 @@ impl<T, E: StdError + Send + Sync> Receiver<T, E> {
mod tests {
use super::{MarshallMessage, Receiver, UnmarshallMessage};
use crate::body::SdkBody;
use crate::event_stream::MessageStreamAdapter;
use crate::event_stream::{EventStreamInput, MessageStreamAdapter};
use crate::result::SdkError;
use async_stream::stream;
use bytes::Bytes;
@ -142,7 +196,9 @@ mod tests {
use futures_util::stream::StreamExt;
use hyper::body::Body;
use smithy_eventstream::error::Error as EventStreamError;
use smithy_eventstream::frame::{Header, HeaderValue, Message, SignMessage, SignMessageError};
use smithy_eventstream::frame::{
Header, HeaderValue, Message, SignMessage, SignMessageError, UnmarshalledMessage,
};
use std::error::Error as StdError;
use std::io::{Error as IOError, ErrorKind};
@ -164,25 +220,31 @@ mod tests {
impl StdError for FakeError {}
#[derive(Debug, Eq, PartialEq)]
struct UnmarshalledMessage(String);
struct TestMessage(String);
#[derive(Debug)]
struct Marshaller;
impl MarshallMessage for Marshaller {
type Input = UnmarshalledMessage;
type Input = TestMessage;
fn marshall(&self, input: Self::Input) -> Result<Message, EventStreamError> {
Ok(Message::new(input.0.as_bytes().to_vec()))
}
}
#[derive(Debug)]
struct Unmarshaller;
impl UnmarshallMessage for Unmarshaller {
type Output = UnmarshalledMessage;
type Output = TestMessage;
type Error = EventStreamError;
fn unmarshall(&self, message: Message) -> Result<Self::Output, EventStreamError> {
Ok(UnmarshalledMessage(
fn unmarshall(
&self,
message: &Message,
) -> Result<UnmarshalledMessage<Self::Output, Self::Error>, EventStreamError> {
Ok(UnmarshalledMessage::Event(TestMessage(
std::str::from_utf8(&message.payload()[..]).unwrap().into(),
))
)))
}
}
@ -192,14 +254,13 @@ mod tests {
vec![Ok(encode_message("one")), Ok(encode_message("two"))];
let chunk_stream = futures_util::stream::iter(chunks);
let body = SdkBody::from(Body::wrap_stream(chunk_stream));
let mut receiver =
Receiver::<UnmarshalledMessage, EventStreamError>::new(Unmarshaller, body);
let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
assert_eq!(
UnmarshalledMessage("one".into()),
TestMessage("one".into()),
receiver.recv().await.unwrap().unwrap()
);
assert_eq!(
UnmarshalledMessage("two".into()),
TestMessage("two".into()),
receiver.recv().await.unwrap().unwrap()
);
}
@ -212,10 +273,9 @@ mod tests {
];
let chunk_stream = futures_util::stream::iter(chunks);
let body = SdkBody::from(Body::wrap_stream(chunk_stream));
let mut receiver =
Receiver::<UnmarshalledMessage, EventStreamError>::new(Unmarshaller, body);
let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
assert_eq!(
UnmarshalledMessage("one".into()),
TestMessage("one".into()),
receiver.recv().await.unwrap().unwrap()
);
assert!(matches!(
@ -234,10 +294,9 @@ mod tests {
];
let chunk_stream = futures_util::stream::iter(chunks);
let body = SdkBody::from(Body::wrap_stream(chunk_stream));
let mut receiver =
Receiver::<UnmarshalledMessage, EventStreamError>::new(Unmarshaller, body);
let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
assert_eq!(
UnmarshalledMessage("one".into()),
TestMessage("one".into()),
receiver.recv().await.unwrap().unwrap()
);
assert!(matches!(
@ -246,6 +305,16 @@ mod tests {
));
}
#[derive(Debug)]
struct TestServiceError;
impl std::fmt::Display for TestServiceError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "TestServiceError")
}
}
impl StdError for TestServiceError {}
#[derive(Debug)]
struct TestSigner;
impl SignMessage for TestSigner {
fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
@ -267,12 +336,15 @@ mod tests {
#[tokio::test]
async fn message_stream_adapter_success() {
let stream = stream! {
yield Ok(UnmarshalledMessage("test".into()));
yield Ok(TestMessage("test".into()));
};
let mut adapter =
check_compatible_with_hyper_wrap_stream(
MessageStreamAdapter::<_, EventStreamError>::new(Marshaller, TestSigner, stream),
);
check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::<
TestMessage,
TestServiceError,
>::new(
Marshaller, TestSigner, Box::pin(stream)
));
let mut sent_bytes = adapter.next().await.unwrap().unwrap();
let sent = Message::read_from(&mut sent_bytes).unwrap();
@ -285,12 +357,15 @@ mod tests {
#[tokio::test]
async fn message_stream_adapter_construction_failure() {
let stream = stream! {
yield Err(EventStreamError::InvalidMessageLength);
yield Err(EventStreamError::InvalidMessageLength.into());
};
let mut adapter =
check_compatible_with_hyper_wrap_stream(
MessageStreamAdapter::<UnmarshalledMessage, _>::new(Marshaller, TestSigner, stream),
);
check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::<
TestMessage,
TestServiceError,
>::new(
Marshaller, TestSigner, Box::pin(stream)
));
let result = adapter.next().await.unwrap();
assert!(result.is_err());
@ -299,4 +374,18 @@ mod tests {
SdkError::ConstructionFailure(_)
));
}
// Verify the developer experience for this compiles
#[allow(unused)]
fn event_stream_input_ergonomics() {
fn check(input: impl Into<EventStreamInput<TestMessage>>) {
let _: EventStreamInput<TestMessage> = input.into();
}
check(stream! {
yield Ok(TestMessage("test".into()));
});
check(stream! {
yield Err(EventStreamError::InvalidMessageLength.into());
});
}
}

View File

@ -159,6 +159,11 @@ impl Request {
}
}
/// Creates a new operation `Request` from its parts.
pub fn from_parts(inner: http::Request<SdkBody>, properties: Arc<Mutex<PropertyBag>>) -> Self {
Request { inner, properties }
}
/// Allows modification of the HTTP request and associated properties with a fallible closure.
pub fn augment<T>(
self,

View File

@ -19,7 +19,7 @@ pub struct SdkSuccess<O> {
/// Failed SDK Result
#[derive(Debug)]
pub enum SdkError<E> {
pub enum SdkError<E, R = operation::Response> {
/// The request failed during construction. It was not dispatched over the network.
ConstructionFailure(BoxError),
@ -35,10 +35,10 @@ pub enum SdkError<E> {
},
/// An error response was received from the service
ServiceError { err: E, raw: operation::Response },
ServiceError { err: E, raw: R },
}
impl<E> Display for SdkError<E>
impl<E, R> Display for SdkError<E, R>
where
E: Error,
{