Refactor event stream tests with `{client,server}IntegrationTest`s (#2342)

* Refactor `ClientEventStreamUnmarshallerGeneratorTest` to use `clientIntegrationTest` (WIP)

* Refactor `ClientEventStreamUnmarshallerGeneratorTest` with `clientIntegrationTest`

* Refactor `ClientEventStreamUnmarshallerGeneratorTest` to use generic test cases

* Start refactoring `ServerEventStreamUnmarshallerGeneratorTest`

* Make `ServerEventStreamUnmarshallerGeneratorTest` tests work

* Uncomment other test models

* Allow unused on `parse_generic_error`

* Rename `ServerEventStreamUnmarshallerGeneratorTest`

* Make `EventStreamUnmarshallTestCases` codegenTarget-agnostic

* Refactor `ClientEventStreamMarshallerGeneratorTest`: Tests run but fail

* Refactor `ServerEventStreamMarshallerGeneratorTest`

* Move `.into()` calls to `conditionalBuilderInput`

* Add "context" to TODO

* Fix client unmarshall tests

* Fix clippy lint

* Fix more clippy lints

* Add docs for `event_stream_serde` module

* Fix client tests

* Remove `#[allow(missing_docs)]` from event stream module

* Remove unused `EventStreamTestTools`

* Add `smithy-validation-model` test dep to `codegen-client`

* Temporarily add docs to make tests compile

* Undo change in model

* Make event stream unmarshaller tests a unit test

* Remove unused code

* Make `ServerEventStreamUnmarshallerGeneratorTest` a unit test

* Make `ServerEventStreamMarshallerGeneratorTest` a unit test

* Make `ServerEventStreamMarshallerGeneratorTest` pass

* Make remaining tests non-integration tests

* Make event stream serde module private again

* Remove unnecessary clippy allowances

* Remove clippy allowance

* Remove docs for `event_stream_serde` module

* Remove docs for `$unmarshallerTypeName::new`

* Remove more unnecessary docs

* Remove more superfluous docs

* Undo unnecessary diffs

* Uncomment last test

* Make `conditionalBuilderInput` internal
This commit is contained in:
Julian Antonielli 2023-02-28 20:26:20 +00:00 committed by GitHub
parent 72df8440c0
commit c3ae6f7eaf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 538 additions and 895 deletions

View File

@ -28,6 +28,10 @@ dependencies {
implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion")
implementation("software.amazon.smithy:smithy-waiters:$smithyVersion")
implementation("software.amazon.smithy:smithy-rules-engine:$smithyVersion")
// `smithy.framework#ValidationException` is defined here, which is used in event stream
// marshalling/unmarshalling tests.
testImplementation("software.amazon.smithy:smithy-validation-model:$smithyVersion")
}
tasks.compileKotlin {

View File

@ -1,98 +0,0 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.client.smithy.protocols.eventstream
import org.junit.jupiter.api.extension.ExtensionContext
import org.junit.jupiter.params.provider.Arguments
import org.junit.jupiter.params.provider.ArgumentsProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.customize.CombinedClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ErrorGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.error.OperationErrorGenerator
import software.amazon.smithy.rust.codegen.client.testutil.testClientRustSettings
import software.amazon.smithy.rust.codegen.client.testutil.testSymbolProvider
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.implBlock
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestRequirements
import software.amazon.smithy.rust.codegen.core.util.expectTrait
import java.util.stream.Stream
class TestCasesProvider : ArgumentsProvider {
override fun provideArguments(context: ExtensionContext?): Stream<out Arguments> =
EventStreamTestModels.TEST_CASES.map { Arguments.of(it) }.stream()
}
abstract class ClientEventStreamBaseRequirements : EventStreamTestRequirements<ClientCodegenContext> {
override fun createCodegenContext(
model: Model,
serviceShape: ServiceShape,
protocolShapeId: ShapeId,
codegenTarget: CodegenTarget,
): ClientCodegenContext = ClientCodegenContext(
model,
testSymbolProvider(model),
serviceShape,
protocolShapeId,
testClientRustSettings(),
CombinedClientCodegenDecorator(emptyList()),
)
override fun renderBuilderForShape(
rustCrate: RustCrate,
writer: RustWriter,
codegenContext: ClientCodegenContext,
shape: StructureShape,
) {
BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape, emptyList()).apply {
render(writer)
}
writer.implBlock(codegenContext.symbolProvider.toSymbol(shape)) {
BuilderGenerator.renderConvenienceMethod(writer, codegenContext.symbolProvider, shape)
}
}
override fun renderOperationError(
writer: RustWriter,
model: Model,
symbolProvider: RustSymbolProvider,
operationOrEventStream: Shape,
) {
OperationErrorGenerator(model, symbolProvider, operationOrEventStream, emptyList()).render(writer)
}
override fun renderError(
rustCrate: RustCrate,
writer: RustWriter,
codegenContext: ClientCodegenContext,
shape: StructureShape,
) {
val errorTrait = shape.expectTrait<ErrorTrait>()
val errorGenerator = ErrorGenerator(
codegenContext.model,
codegenContext.symbolProvider,
shape,
errorTrait,
emptyList(),
)
rustCrate.useShapeWriter(shape) {
errorGenerator.renderStruct(this)
}
rustCrate.withModule(codegenContext.symbolProvider.moduleForBuilder(shape)) {
errorGenerator.renderBuilder(this)
}
}
}

View File

@ -5,43 +5,30 @@
package software.amazon.smithy.rust.codegen.client.smithy.protocols.eventstream
import org.junit.jupiter.api.extension.ExtensionContext
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.Arguments
import org.junit.jupiter.params.provider.ArgumentsProvider
import org.junit.jupiter.params.provider.ArgumentsSource
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.EventStreamMarshallerGenerator
import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamMarshallTestCases.writeMarshallTestCases
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety
import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject
import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
import software.amazon.smithy.rust.codegen.core.testutil.testModule
import java.util.stream.Stream
class ClientEventStreamMarshallerGeneratorTest {
@ParameterizedTest
@ArgumentsSource(TestCasesProvider::class)
fun test(testCase: EventStreamTestModels.TestCase) {
EventStreamTestTools.setupTestCase(
testCase,
object : ClientEventStreamBaseRequirements() {
override fun renderGenerator(
codegenContext: ClientCodegenContext,
project: TestEventStreamProject,
protocol: Protocol,
): RuntimeType = EventStreamMarshallerGenerator(
project.model,
CodegenTarget.CLIENT,
TestRuntimeConfig,
project.symbolProvider,
project.streamShape,
protocol.structuredDataSerializer(project.operationShape),
testCase.requestContentType,
).render()
},
CodegenTarget.CLIENT,
EventStreamTestVariety.Marshall,
).compileAndTest()
clientIntegrationTest(testCase.model) { _, rustCrate ->
rustCrate.testModule {
writeMarshallTestCases(testCase, optionalBuilderInputs = false)
}
}
}
}
class TestCasesProvider : ArgumentsProvider {
override fun provideArguments(context: ExtensionContext?): Stream<out Arguments> =
EventStreamTestModels.TEST_CASES.map { Arguments.of(it) }.stream()
}

View File

@ -7,39 +7,60 @@ package software.amazon.smithy.rust.codegen.client.smithy.protocols.eventstream
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ArgumentsSource
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.EventStreamUnmarshallerGenerator
import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety
import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.writeUnmarshallTestCases
import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams
import software.amazon.smithy.rust.codegen.core.testutil.testModule
import software.amazon.smithy.rust.codegen.core.testutil.unitTest
class ClientEventStreamUnmarshallerGeneratorTest {
@ParameterizedTest
@ArgumentsSource(TestCasesProvider::class)
fun test(testCase: EventStreamTestModels.TestCase) {
EventStreamTestTools.setupTestCase(
testCase,
object : ClientEventStreamBaseRequirements() {
override fun renderGenerator(
codegenContext: ClientCodegenContext,
project: TestEventStreamProject,
protocol: Protocol,
): RuntimeType {
return EventStreamUnmarshallerGenerator(
protocol,
codegenContext,
project.operationShape,
project.streamShape,
).render()
}
},
CodegenTarget.CLIENT,
EventStreamTestVariety.Unmarshall,
).compileAndTest()
clientIntegrationTest(
testCase.model,
IntegrationTestParams(service = "test#TestService", addModuleToEventStreamAllowList = true),
) { _, rustCrate ->
val generator = "crate::event_stream_serde::TestStreamUnmarshaller"
rustCrate.testModule {
rust("##![allow(unused_imports, dead_code)]")
writeUnmarshallTestCases(testCase, optionalBuilderInputs = false)
unitTest(
"unknown_message",
"""
let message = msg("event", "NewUnmodeledMessageType", "application/octet-stream", b"hello, world!");
let result = $generator::new().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert!(expect_event(result.unwrap()).is_unknown());
""",
)
unitTest(
"generic_error",
"""
let message = msg(
"exception",
"UnmodeledError",
"${testCase.responseContentType}",
br#"${testCase.validUnmodeledError}"#
);
let result = $generator::new().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
match expect_error(result.unwrap()) {
TestStreamError::Unhandled(err) => {
let message = format!("{}", aws_smithy_types::error::display::DisplayErrorContext(&err));
let expected = "message: \"unmodeled error\"";
assert!(message.contains(expected), "Expected '{message}' to contain '{expected}'");
}
kind => panic!("expected generic error, but got {:?}", kind),
}
""",
)
}
}
}
}

View File

@ -470,9 +470,11 @@ class Attribute(val inner: Writable) {
val AllowDeprecated = Attribute(allow("deprecated"))
val AllowIrrefutableLetPatterns = Attribute(allow("irrefutable_let_patterns"))
val AllowUnreachableCode = Attribute(allow("unreachable_code"))
val AllowUnreachablePatterns = Attribute(allow("unreachable_patterns"))
val AllowUnusedImports = Attribute(allow("unused_imports"))
val AllowUnusedMut = Attribute(allow("unused_mut"))
val AllowUnusedVariables = Attribute(allow("unused_variables"))
val AllowMissingDocs = Attribute(allow("missing_docs"))
val CfgTest = Attribute(cfg("test"))
val DenyMissingDocs = Attribute(deny("missing_docs"))
val DocHidden = Attribute(doc("hidden"))

View File

@ -116,7 +116,7 @@ private fun <T : AbstractCodeWriter<T>, U> T.withTemplate(
* This enables conditionally wrapping a block in a prefix/suffix, e.g.
*
* ```
* writer.withBlock("Some(", ")", conditional = symbol.isOptional()) {
* writer.conditionalBlock("Some(", ")", conditional = symbol.isOptional()) {
* write("symbolValue")
* }
* ```

View File

@ -43,6 +43,9 @@ import software.amazon.smithy.rust.codegen.core.util.expectTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.toPascalCase
fun RustModule.Companion.eventStreamSerdeModule(): RustModule.LeafModule =
private("event_stream_serde")
class EventStreamUnmarshallerGenerator(
private val protocol: Protocol,
codegenContext: CodegenContext,
@ -60,7 +63,7 @@ class EventStreamUnmarshallerGenerator(
symbolProvider.symbolForEventStreamError(unionShape)
}
private val smithyEventStream = RuntimeType.smithyEventStream(runtimeConfig)
private val eventStreamSerdeModule = RustModule.private("event_stream_serde")
private val eventStreamSerdeModule = RustModule.eventStreamSerdeModule()
private val codegenScope = arrayOf(
"Blob" to RuntimeType.blob(runtimeConfig),
"expect_fns" to smithyEventStream.resolve("smithy"),
@ -84,15 +87,16 @@ class EventStreamUnmarshallerGenerator(
}
private fun RustWriter.renderUnmarshaller(unmarshallerType: RuntimeType, unionSymbol: Symbol) {
val unmarshallerTypeName = unmarshallerType.name
rust(
"""
##[non_exhaustive]
##[derive(Debug)]
pub struct ${unmarshallerType.name};
pub struct $unmarshallerTypeName;
impl ${unmarshallerType.name} {
impl $unmarshallerTypeName {
pub fn new() -> Self {
${unmarshallerType.name}
$unmarshallerTypeName
}
}
""",
@ -154,6 +158,7 @@ class EventStreamUnmarshallerGenerator(
"Output" to unionSymbol,
*codegenScope,
)
false -> rustTemplate(
"return Err(#{Error}::unmarshalling(format!(\"unrecognized :event-type: {}\", _unknown_variant)));",
*codegenScope,
@ -179,6 +184,7 @@ class EventStreamUnmarshallerGenerator(
*codegenScope,
)
}
payloadOnly -> {
withBlock("let parsed = ", ";") {
renderParseProtocolPayload(unionMember)
@ -189,6 +195,7 @@ class EventStreamUnmarshallerGenerator(
*codegenScope,
)
}
else -> {
rust("let mut builder = #T::default();", symbolProvider.symbolForBuilder(unionStruct))
val payloadMember = unionStruct.members().firstOrNull { it.hasTrait<EventPayloadTrait>() }
@ -265,6 +272,7 @@ class EventStreamUnmarshallerGenerator(
is BlobShape -> {
rustTemplate("#{Blob}::new(message.payload().as_ref())", *codegenScope)
}
is StringShape -> {
rustTemplate(
"""
@ -275,6 +283,7 @@ class EventStreamUnmarshallerGenerator(
*codegenScope,
)
}
is UnionShape, is StructureShape -> {
renderParseProtocolPayload(member)
}
@ -312,6 +321,7 @@ class EventStreamUnmarshallerGenerator(
*codegenScope,
)
}
CodegenTarget.SERVER -> {}
}
@ -350,6 +360,7 @@ class EventStreamUnmarshallerGenerator(
)
}
}
CodegenTarget.SERVER -> {
val target = model.expectShape(member.target, StructureShape::class.java)
val parser = protocol.structuredDataParser(operationShape).errorParser(target)
@ -391,6 +402,7 @@ class EventStreamUnmarshallerGenerator(
CodegenTarget.CLIENT -> {
rustTemplate("Ok(#{UnmarshalledMessage}::Error(#{OpError}::generic(generic)))", *codegenScope)
}
CodegenTarget.SERVER -> {
rustTemplate(
"""

View File

@ -25,6 +25,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant
import software.amazon.smithy.rust.codegen.core.smithy.generators.unknownVariantError
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.eventStreamSerdeModule
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticEventStreamUnionTrait
import software.amazon.smithy.rust.codegen.core.smithy.transformers.eventStreamErrors
@ -49,7 +50,7 @@ class EventStreamErrorMarshallerGenerator(
} else {
symbolProvider.symbolForEventStreamError(unionShape)
}
private val eventStreamSerdeModule = RustModule.private("event_stream_serde")
private val eventStreamSerdeModule = RustModule.eventStreamSerdeModule()
private val errorsShape = unionShape.expectTrait<SyntheticEventStreamUnionTrait>()
private val codegenScope = arrayOf(
"MarshallMessage" to smithyEventStream.resolve("frame::MarshallMessage"),
@ -126,7 +127,7 @@ class EventStreamErrorMarshallerGenerator(
}
}
fun RustWriter.renderMarshallEvent(unionMember: MemberShape, eventStruct: StructureShape) {
private fun RustWriter.renderMarshallEvent(unionMember: MemberShape, eventStruct: StructureShape) {
val headerMembers = eventStruct.members().filter { it.hasTrait<EventHeaderTrait>() }
val payloadMember = eventStruct.members().firstOrNull { it.hasTrait<EventPayloadTrait>() }
for (member in headerMembers) {

View File

@ -38,6 +38,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant
import software.amazon.smithy.rust.codegen.core.smithy.generators.unknownVariantError
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.eventStreamSerdeModule
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.hasTrait
@ -53,7 +54,7 @@ open class EventStreamMarshallerGenerator(
private val payloadContentType: String,
) {
private val smithyEventStream = RuntimeType.smithyEventStream(runtimeConfig)
private val eventStreamSerdeModule = RustModule.private("event_stream_serde")
private val eventStreamSerdeModule = RustModule.eventStreamSerdeModule()
private val codegenScope = arrayOf(
"MarshallMessage" to smithyEventStream.resolve("frame::MarshallMessage"),
"Message" to smithyEventStream.resolve("frame::Message"),

View File

@ -5,26 +5,35 @@
package software.amazon.smithy.rust.codegen.core.testutil
import org.intellij.lang.annotations.Language
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.util.dq
internal object EventStreamMarshallTestCases {
internal fun RustWriter.writeMarshallTestCases(
object EventStreamMarshallTestCases {
fun RustWriter.writeMarshallTestCases(
testCase: EventStreamTestModels.TestCase,
generator: RuntimeType,
optionalBuilderInputs: Boolean,
) {
val generator = "crate::event_stream_serde::TestStreamMarshaller"
val protocolTestHelpers = CargoDependency.smithyProtocolTestHelpers(TestRuntimeConfig)
.copy(scope = DependencyScope.Compile)
fun builderInput(
@Language("Rust", prefix = "macro_rules! foo { () => {{\n", suffix = "\n}}}")
input: String,
vararg ctx: Pair<String, Any>,
): Writable = conditionalBuilderInput(input, conditional = optionalBuilderInputs, ctx = ctx)
rustTemplate(
"""
use aws_smithy_eventstream::frame::{Message, Header, HeaderValue, MarshallMessage};
use std::collections::HashMap;
use aws_smithy_types::{Blob, DateTime};
use crate::error::*;
use crate::model::*;
use #{validate_body};
@ -46,163 +55,192 @@ internal object EventStreamMarshallTestCases {
"MediaType" to protocolTestHelpers.toType().resolve("MediaType"),
)
unitTest(
"message_with_blob",
"""
let event = TestStream::MessageWithBlob(
MessageWithBlob::builder().data(Blob::new(&b"hello, world!"[..])).build()
);
let result = ${format(generator)}().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let message = result.unwrap();
let headers = headers_to_map(message.headers());
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
assert_eq!(&str_header("MessageWithBlob"), *headers.get(":event-type").unwrap());
assert_eq!(&str_header("application/octet-stream"), *headers.get(":content-type").unwrap());
assert_eq!(&b"hello, world!"[..], message.payload());
""",
)
unitTest("message_with_blob") {
rustTemplate(
"""
let event = TestStream::MessageWithBlob(
MessageWithBlob::builder().data(#{BlobInput:W}).build()
);
let result = $generator::new().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let message = result.unwrap();
let headers = headers_to_map(message.headers());
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
assert_eq!(&str_header("MessageWithBlob"), *headers.get(":event-type").unwrap());
assert_eq!(&str_header("application/octet-stream"), *headers.get(":content-type").unwrap());
assert_eq!(&b"hello, world!"[..], message.payload());
""",
"BlobInput" to builderInput("Blob::new(&b\"hello, world!\"[..])"),
)
}
unitTest(
"message_with_string",
"""
let event = TestStream::MessageWithString(
MessageWithString::builder().data("hello, world!").build()
);
let result = ${format(generator)}().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let message = result.unwrap();
let headers = headers_to_map(message.headers());
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
assert_eq!(&str_header("MessageWithString"), *headers.get(":event-type").unwrap());
assert_eq!(&str_header("text/plain"), *headers.get(":content-type").unwrap());
assert_eq!(&b"hello, world!"[..], message.payload());
""",
)
unitTest("message_with_string") {
rustTemplate(
"""
let event = TestStream::MessageWithString(
MessageWithString::builder().data(#{StringInput}).build()
);
let result = $generator::new().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let message = result.unwrap();
let headers = headers_to_map(message.headers());
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
assert_eq!(&str_header("MessageWithString"), *headers.get(":event-type").unwrap());
assert_eq!(&str_header("text/plain"), *headers.get(":content-type").unwrap());
assert_eq!(&b"hello, world!"[..], message.payload());
""",
"StringInput" to builderInput("\"hello, world!\""),
)
}
unitTest(
"message_with_struct",
"""
let event = TestStream::MessageWithStruct(
MessageWithStruct::builder().some_struct(
TestStruct::builder()
.some_string("hello")
.some_int(5)
.build()
).build()
);
let result = ${format(generator)}().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let message = result.unwrap();
let headers = headers_to_map(message.headers());
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
assert_eq!(&str_header("MessageWithStruct"), *headers.get(":event-type").unwrap());
assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap());
unitTest("message_with_struct") {
rustTemplate(
"""
let event = TestStream::MessageWithStruct(
MessageWithStruct::builder().some_struct(#{StructInput}).build()
);
let result = $generator::new().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let message = result.unwrap();
let headers = headers_to_map(message.headers());
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
assert_eq!(&str_header("MessageWithStruct"), *headers.get(":event-type").unwrap());
assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap());
validate_body(
message.payload(),
${testCase.validTestStruct.dq()},
MediaType::from(${testCase.requestContentType.dq()})
).unwrap();
""",
)
validate_body(
message.payload(),
${testCase.validTestStruct.dq()},
MediaType::from(${testCase.mediaType.dq()})
).unwrap();
""",
"StructInput" to
builderInput(
"""
TestStruct::builder()
.some_string(#{StringInput})
.some_int(#{IntInput})
.build()
""",
"IntInput" to builderInput("5"),
"StringInput" to builderInput("\"hello\""),
),
)
}
unitTest(
"message_with_union",
"""
let event = TestStream::MessageWithUnion(MessageWithUnion::builder().some_union(
TestUnion::Foo("hello".into())
).build());
let result = ${format(generator)}().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let message = result.unwrap();
let headers = headers_to_map(message.headers());
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
assert_eq!(&str_header("MessageWithUnion"), *headers.get(":event-type").unwrap());
assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap());
unitTest("message_with_union") {
rustTemplate(
"""
let event = TestStream::MessageWithUnion(MessageWithUnion::builder()
.some_union(#{UnionInput})
.build()
);
let result = $generator::new().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let message = result.unwrap();
let headers = headers_to_map(message.headers());
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
assert_eq!(&str_header("MessageWithUnion"), *headers.get(":event-type").unwrap());
assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap());
validate_body(
message.payload(),
${testCase.validTestUnion.dq()},
MediaType::from(${testCase.requestContentType.dq()})
).unwrap();
""",
)
validate_body(
message.payload(),
${testCase.validTestUnion.dq()},
MediaType::from(${testCase.mediaType.dq()})
).unwrap();
""",
"UnionInput" to builderInput("TestUnion::Foo(\"hello\".into())"),
)
}
unitTest(
"message_with_headers",
"""
let event = TestStream::MessageWithHeaders(MessageWithHeaders::builder()
.blob(Blob::new(&b"test"[..]))
.boolean(true)
.byte(55i8)
.int(100_000i32)
.long(9_000_000_000i64)
.short(16_000i16)
.string("test")
.timestamp(DateTime::from_secs(5))
.build()
);
let result = ${format(generator)}().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let actual_message = result.unwrap();
let expected_message = Message::new(&b""[..])
.add_header(Header::new(":message-type", HeaderValue::String("event".into())))
.add_header(Header::new(":event-type", HeaderValue::String("MessageWithHeaders".into())))
.add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into())))
.add_header(Header::new("boolean", HeaderValue::Bool(true)))
.add_header(Header::new("byte", HeaderValue::Byte(55i8)))
.add_header(Header::new("int", HeaderValue::Int32(100_000i32)))
.add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64)))
.add_header(Header::new("short", HeaderValue::Int16(16_000i16)))
.add_header(Header::new("string", HeaderValue::String("test".into())))
.add_header(Header::new("timestamp", HeaderValue::Timestamp(DateTime::from_secs(5))));
assert_eq!(expected_message, actual_message);
""",
)
unitTest("message_with_headers") {
rustTemplate(
"""
let event = TestStream::MessageWithHeaders(MessageWithHeaders::builder()
.blob(#{BlobInput})
.boolean(#{BooleanInput})
.byte(#{ByteInput})
.int(#{IntInput})
.long(#{LongInput})
.short(#{ShortInput})
.string(#{StringInput})
.timestamp(#{TimestampInput})
.build()
);
let result = $generator::new().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let actual_message = result.unwrap();
let expected_message = Message::new(&b""[..])
.add_header(Header::new(":message-type", HeaderValue::String("event".into())))
.add_header(Header::new(":event-type", HeaderValue::String("MessageWithHeaders".into())))
.add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into())))
.add_header(Header::new("boolean", HeaderValue::Bool(true)))
.add_header(Header::new("byte", HeaderValue::Byte(55i8)))
.add_header(Header::new("int", HeaderValue::Int32(100_000i32)))
.add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64)))
.add_header(Header::new("short", HeaderValue::Int16(16_000i16)))
.add_header(Header::new("string", HeaderValue::String("test".into())))
.add_header(Header::new("timestamp", HeaderValue::Timestamp(DateTime::from_secs(5))));
assert_eq!(expected_message, actual_message);
""",
"BlobInput" to builderInput("Blob::new(&b\"test\"[..])"),
"BooleanInput" to builderInput("true"),
"ByteInput" to builderInput("55i8"),
"IntInput" to builderInput("100_000i32"),
"LongInput" to builderInput("9_000_000_000i64"),
"ShortInput" to builderInput("16_000i16"),
"StringInput" to builderInput("\"test\""),
"TimestampInput" to builderInput("DateTime::from_secs(5)"),
)
}
unitTest(
"message_with_header_and_payload",
"""
let event = TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder()
.header("header")
.payload(Blob::new(&b"payload"[..]))
.build()
);
let result = ${format(generator)}().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let actual_message = result.unwrap();
let expected_message = Message::new(&b"payload"[..])
.add_header(Header::new(":message-type", HeaderValue::String("event".into())))
.add_header(Header::new(":event-type", HeaderValue::String("MessageWithHeaderAndPayload".into())))
.add_header(Header::new("header", HeaderValue::String("header".into())))
.add_header(Header::new(":content-type", HeaderValue::String("application/octet-stream".into())));
assert_eq!(expected_message, actual_message);
""",
)
unitTest("message_with_header_and_payload") {
rustTemplate(
"""
let event = TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder()
.header(#{HeaderInput})
.payload(#{PayloadInput})
.build()
);
let result = $generator::new().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let actual_message = result.unwrap();
let expected_message = Message::new(&b"payload"[..])
.add_header(Header::new(":message-type", HeaderValue::String("event".into())))
.add_header(Header::new(":event-type", HeaderValue::String("MessageWithHeaderAndPayload".into())))
.add_header(Header::new("header", HeaderValue::String("header".into())))
.add_header(Header::new(":content-type", HeaderValue::String("application/octet-stream".into())));
assert_eq!(expected_message, actual_message);
""",
"HeaderInput" to builderInput("\"header\""),
"PayloadInput" to builderInput("Blob::new(&b\"payload\"[..])"),
)
}
unitTest(
"message_with_no_header_payload_traits",
"""
let event = TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder()
.some_int(5)
.some_string("hello")
.build()
);
let result = ${format(generator)}().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let message = result.unwrap();
let headers = headers_to_map(message.headers());
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
assert_eq!(&str_header("MessageWithNoHeaderPayloadTraits"), *headers.get(":event-type").unwrap());
assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap());
unitTest("message_with_no_header_payload_traits") {
rustTemplate(
"""
let event = TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder()
.some_int(#{IntInput})
.some_string(#{StringInput})
.build()
);
let result = $generator::new().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let message = result.unwrap();
let headers = headers_to_map(message.headers());
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
assert_eq!(&str_header("MessageWithNoHeaderPayloadTraits"), *headers.get(":event-type").unwrap());
assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap());
validate_body(
message.payload(),
${testCase.validMessageWithNoHeaderPayloadTraits.dq()},
MediaType::from(${testCase.requestContentType.dq()})
).unwrap();
""",
)
validate_body(
message.payload(),
${testCase.validMessageWithNoHeaderPayloadTraits.dq()},
MediaType::from(${testCase.mediaType.dq()})
).unwrap();
""",
"IntInput" to builderInput("5"),
"StringInput" to builderInput("\"hello\""),
)
}
}
}

View File

@ -19,6 +19,7 @@ private fun fillInBaseModel(
): String = """
namespace test
use smithy.framework#ValidationException
use aws.protocols#$protocolName
union TestUnion {
@ -69,12 +70,20 @@ private fun fillInBaseModel(
MessageWithNoHeaderPayloadTraits: MessageWithNoHeaderPayloadTraits,
SomeError: SomeError,
}
structure TestStreamInputOutput { @httpPayload @required value: TestStream }
structure TestStreamInputOutput {
@required
@httpPayload
value: TestStream
}
@http(method: "POST", uri: "/test")
operation TestStreamOp {
input: TestStreamInputOutput,
output: TestStreamInputOutput,
errors: [SomeError],
errors: [SomeError, ValidationException],
}
$extraServiceAnnotations
@$protocolName
service TestService { version: "123", operations: [TestStreamOp] }
@ -92,6 +101,7 @@ object EventStreamTestModels {
data class TestCase(
val protocolShapeId: String,
val model: Model,
val mediaType: String,
val requestContentType: String,
val responseContentType: String,
val validTestStruct: String,
@ -111,7 +121,8 @@ object EventStreamTestModels {
TestCase(
protocolShapeId = "aws.protocols#restJson1",
model = restJson1(),
requestContentType = "application/json",
mediaType = "application/json",
requestContentType = "application/vnd.amazon.eventstream",
responseContentType = "application/json",
validTestStruct = """{"someString":"hello","someInt":5}""",
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
@ -126,6 +137,7 @@ object EventStreamTestModels {
TestCase(
protocolShapeId = "aws.protocols#awsJson1_1",
model = awsJson11(),
mediaType = "application/x-amz-json-1.1",
requestContentType = "application/x-amz-json-1.1",
responseContentType = "application/x-amz-json-1.1",
validTestStruct = """{"someString":"hello","someInt":5}""",
@ -141,7 +153,8 @@ object EventStreamTestModels {
TestCase(
protocolShapeId = "aws.protocols#restXml",
model = restXml(),
requestContentType = "application/xml",
mediaType = "application/xml",
requestContentType = "application/vnd.amazon.eventstream",
responseContentType = "application/xml",
validTestStruct = """
<TestStruct>

View File

@ -1,185 +0,0 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.core.testutil
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.DirectedWalker
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.transformers.EventStreamNormalizer
import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamMarshallTestCases.writeMarshallTestCases
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.writeUnmarshallTestCases
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.lookup
import software.amazon.smithy.rust.codegen.core.util.outputShape
data class TestEventStreamProject(
val model: Model,
val serviceShape: ServiceShape,
val operationShape: OperationShape,
val streamShape: UnionShape,
val symbolProvider: RustSymbolProvider,
val project: TestWriterDelegator,
)
enum class EventStreamTestVariety {
Marshall,
Unmarshall,
}
interface EventStreamTestRequirements<C : CodegenContext> {
/** Create a codegen context for the tests */
fun createCodegenContext(
model: Model,
serviceShape: ServiceShape,
protocolShapeId: ShapeId,
codegenTarget: CodegenTarget,
): C
/** Render the event stream marshall/unmarshall code generator */
fun renderGenerator(
codegenContext: C,
project: TestEventStreamProject,
protocol: Protocol,
): RuntimeType
/** Render a builder for the given shape */
fun renderBuilderForShape(
rustCrate: RustCrate,
writer: RustWriter,
codegenContext: C,
shape: StructureShape,
)
/** Render an operation error for the given operation and error shapes */
fun renderOperationError(
writer: RustWriter,
model: Model,
symbolProvider: RustSymbolProvider,
operationOrEventStream: Shape,
)
/** Render an error struct and builder */
fun renderError(rustCrate: RustCrate, writer: RustWriter, codegenContext: C, shape: StructureShape)
}
object EventStreamTestTools {
fun <C : CodegenContext> setupTestCase(
testCase: EventStreamTestModels.TestCase,
requirements: EventStreamTestRequirements<C>,
codegenTarget: CodegenTarget,
variety: EventStreamTestVariety,
transformers: List<(Model) -> Model> = listOf(),
): TestWriterDelegator {
val model = (listOf(OperationNormalizer::transform, EventStreamNormalizer::transform) + transformers).fold(testCase.model) { model, transformer ->
transformer(model)
}
val serviceShape = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape
val codegenContext = requirements.createCodegenContext(
model,
serviceShape,
ShapeId.from(testCase.protocolShapeId),
codegenTarget,
)
val test = generateTestProject(requirements, codegenContext, codegenTarget)
val protocol = testCase.protocolBuilder(codegenContext)
val generator = requirements.renderGenerator(codegenContext, test, protocol)
test.project.lib {
when (variety) {
EventStreamTestVariety.Marshall -> writeMarshallTestCases(testCase, generator)
EventStreamTestVariety.Unmarshall -> writeUnmarshallTestCases(testCase, codegenTarget, generator)
}
}
return test.project
}
private fun <C : CodegenContext> generateTestProject(
requirements: EventStreamTestRequirements<C>,
codegenContext: C,
codegenTarget: CodegenTarget,
): TestEventStreamProject {
val model = codegenContext.model
val symbolProvider = codegenContext.symbolProvider
val operationShape = model.expectShape(ShapeId.from("test#TestStreamOp")) as OperationShape
val unionShape = model.expectShape(ShapeId.from("test#TestStream")) as UnionShape
val walker = DirectedWalker(model)
val project = TestWorkspace.testProject(symbolProvider)
val errors = model.serviceShapes
.flatMap { walker.walkShapes(it) }
.filterIsInstance<StructureShape>()
.filter { shape -> shape.hasTrait<ErrorTrait>() }
check(errors.isNotEmpty()) { "must have at least one error modeled" }
project.withModule(codegenContext.symbolProvider.moduleForShape(errors[0])) {
requirements.renderOperationError(this, model, symbolProvider, operationShape)
requirements.renderOperationError(this, model, symbolProvider, unionShape)
for (shape in errors) {
requirements.renderError(project, this, codegenContext, shape)
}
}
val inputOutput = model.lookup<StructureShape>("test#TestStreamInputOutput")
project.withModule(codegenContext.symbolProvider.moduleForShape(inputOutput)) {
recursivelyGenerateModels(project, model, symbolProvider, inputOutput, this, codegenTarget)
}
operationShape.outputShape(model).also { outputShape ->
outputShape.renderWithModelBuilder(model, symbolProvider, project)
}
return TestEventStreamProject(
model,
codegenContext.serviceShape,
operationShape,
unionShape,
symbolProvider,
project,
)
}
private fun recursivelyGenerateModels(
rustCrate: RustCrate,
model: Model,
symbolProvider: RustSymbolProvider,
shape: Shape,
writer: RustWriter,
mode: CodegenTarget,
) {
for (member in shape.members()) {
if (member.target.namespace == "smithy.api") {
continue
}
val target = model.expectShape(member.target)
when (target) {
is StructureShape -> target.renderWithModelBuilder(model, symbolProvider, rustCrate)
is UnionShape -> UnionGenerator(
model,
symbolProvider,
writer,
target,
renderUnknownVariant = mode.renderUnknownVariant(),
).render()
else -> TODO("EventStreamTestTools doesn't support rendering $target")
}
recursivelyGenerateModels(rustCrate, model, symbolProvider, target, writer, mode)
}
}
}

View File

@ -5,17 +5,22 @@
package software.amazon.smithy.rust.codegen.core.testutil
import org.intellij.lang.annotations.Language
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
internal object EventStreamUnmarshallTestCases {
internal fun RustWriter.writeUnmarshallTestCases(
object EventStreamUnmarshallTestCases {
fun RustWriter.writeUnmarshallTestCases(
testCase: EventStreamTestModels.TestCase,
codegenTarget: CodegenTarget,
generator: RuntimeType,
optionalBuilderInputs: Boolean = false,
) {
val generator = "crate::event_stream_serde::TestStreamUnmarshaller"
rust(
"""
use aws_smithy_eventstream::frame::{Header, HeaderValue, Message, UnmarshallMessage, UnmarshalledMessage};
@ -53,202 +58,199 @@ internal object EventStreamUnmarshallTestCases {
""",
)
unitTest(
name = "message_with_blob",
test = """
unitTest("message_with_blob") {
rustTemplate(
"""
let message = msg("event", "MessageWithBlob", "application/octet-stream", b"hello, world!");
let result = ${format(generator)}().unmarshall(&message);
let result = $generator::new().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithBlob(
MessageWithBlob::builder().data(Blob::new(&b"hello, world!"[..])).build()
MessageWithBlob::builder().data(#{DataInput:W}).build()
),
expect_event(result.unwrap())
);
""",
)
if (codegenTarget == CodegenTarget.CLIENT) {
unitTest(
"unknown_message",
"""
let message = msg("event", "NewUnmodeledMessageType", "application/octet-stream", b"hello, world!");
let result = ${format(generator)}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::Unknown,
expect_event(result.unwrap())
);
""",
"DataInput" to conditionalBuilderInput(
"""
Blob::new(&b"hello, world!"[..])
""",
conditional = optionalBuilderInputs,
),
)
}
unitTest(
"message_with_string",
"""
let message = msg("event", "MessageWithString", "text/plain", b"hello, world!");
let result = ${format(generator)}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithString(MessageWithString::builder().data("hello, world!").build()),
expect_event(result.unwrap())
);
""",
)
unitTest("message_with_string") {
rustTemplate(
"""
let message = msg("event", "MessageWithString", "text/plain", b"hello, world!");
let result = $generator::new().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithString(MessageWithString::builder().data(#{DataInput}).build()),
expect_event(result.unwrap())
);
""",
"DataInput" to conditionalBuilderInput("\"hello, world!\"", conditional = optionalBuilderInputs),
)
}
unitTest(
"message_with_struct",
"""
let message = msg(
"event",
"MessageWithStruct",
"${testCase.responseContentType}",
br#"${testCase.validTestStruct}"#
);
let result = ${format(generator)}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithStruct(MessageWithStruct::builder().some_struct(
unitTest("message_with_struct") {
rustTemplate(
"""
let message = msg(
"event",
"MessageWithStruct",
"${testCase.responseContentType}",
br##"${testCase.validTestStruct}"##
);
let result = $generator::new().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithStruct(MessageWithStruct::builder().some_struct(#{StructInput}).build()),
expect_event(result.unwrap())
);
""",
"StructInput" to conditionalBuilderInput(
"""
TestStruct::builder()
.some_string("hello")
.some_int(5)
.some_string(#{StringInput})
.some_int(#{IntInput})
.build()
).build()),
expect_event(result.unwrap())
);
""",
)
unitTest(
"message_with_union",
"""
let message = msg(
"event",
"MessageWithUnion",
"${testCase.responseContentType}",
br#"${testCase.validTestUnion}"#
);
let result = ${format(generator)}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithUnion(MessageWithUnion::builder().some_union(
TestUnion::Foo("hello".into())
).build()),
expect_event(result.unwrap())
);
""",
)
unitTest(
"message_with_headers",
"""
let message = msg("event", "MessageWithHeaders", "application/octet-stream", b"")
.add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into())))
.add_header(Header::new("boolean", HeaderValue::Bool(true)))
.add_header(Header::new("byte", HeaderValue::Byte(55i8)))
.add_header(Header::new("int", HeaderValue::Int32(100_000i32)))
.add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64)))
.add_header(Header::new("short", HeaderValue::Int16(16_000i16)))
.add_header(Header::new("string", HeaderValue::String("test".into())))
.add_header(Header::new("timestamp", HeaderValue::Timestamp(DateTime::from_secs(5))));
let result = ${format(generator)}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithHeaders(MessageWithHeaders::builder()
.blob(Blob::new(&b"test"[..]))
.boolean(true)
.byte(55i8)
.int(100_000i32)
.long(9_000_000_000i64)
.short(16_000i16)
.string("test")
.timestamp(DateTime::from_secs(5))
.build()
""",
conditional = optionalBuilderInputs,
"StringInput" to conditionalBuilderInput("\"hello\"", conditional = optionalBuilderInputs),
"IntInput" to conditionalBuilderInput("5", conditional = optionalBuilderInputs),
),
expect_event(result.unwrap())
);
""",
)
unitTest(
"message_with_header_and_payload",
"""
let message = msg("event", "MessageWithHeaderAndPayload", "application/octet-stream", b"payload")
.add_header(Header::new("header", HeaderValue::String("header".into())));
let result = ${format(generator)}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder()
.header("header")
.payload(Blob::new(&b"payload"[..]))
.build()
),
expect_event(result.unwrap())
);
""",
)
)
}
unitTest(
"message_with_no_header_payload_traits",
"""
let message = msg(
"event",
"MessageWithNoHeaderPayloadTraits",
"${testCase.responseContentType}",
br#"${testCase.validMessageWithNoHeaderPayloadTraits}"#
);
let result = ${format(generator)}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder()
.some_int(5)
.some_string("hello")
.build()
),
expect_event(result.unwrap())
);
""",
)
unitTest("message_with_union") {
rustTemplate(
"""
let message = msg(
"event",
"MessageWithUnion",
"${testCase.responseContentType}",
br##"${testCase.validTestUnion}"##
);
let result = $generator::new().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithUnion(MessageWithUnion::builder().some_union(#{UnionInput}).build()),
expect_event(result.unwrap())
);
""",
"UnionInput" to conditionalBuilderInput("TestUnion::Foo(\"hello\".into())", conditional = optionalBuilderInputs),
)
}
unitTest(
"some_error",
"""
let message = msg(
"exception",
"SomeError",
"${testCase.responseContentType}",
br#"${testCase.validSomeError}"#
);
let result = ${format(generator)}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
match expect_error(result.unwrap()) {
TestStreamError::SomeError(err) => assert_eq!(Some("some error"), err.message()),
kind => panic!("expected SomeError, but got {:?}", kind),
}
""",
)
unitTest("message_with_headers") {
rustTemplate(
"""
let message = msg("event", "MessageWithHeaders", "application/octet-stream", b"")
.add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into())))
.add_header(Header::new("boolean", HeaderValue::Bool(true)))
.add_header(Header::new("byte", HeaderValue::Byte(55i8)))
.add_header(Header::new("int", HeaderValue::Int32(100_000i32)))
.add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64)))
.add_header(Header::new("short", HeaderValue::Int16(16_000i16)))
.add_header(Header::new("string", HeaderValue::String("test".into())))
.add_header(Header::new("timestamp", HeaderValue::Timestamp(DateTime::from_secs(5))));
let result = $generator::new().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithHeaders(MessageWithHeaders::builder()
.blob(#{BlobInput})
.boolean(#{BoolInput})
.byte(#{ByteInput})
.int(#{IntInput})
.long(#{LongInput})
.short(#{ShortInput})
.string(#{StringInput})
.timestamp(#{TimestampInput})
.build()
),
expect_event(result.unwrap())
);
""",
"BlobInput" to conditionalBuilderInput("Blob::new(&b\"test\"[..])", conditional = optionalBuilderInputs),
"BoolInput" to conditionalBuilderInput("true", conditional = optionalBuilderInputs),
"ByteInput" to conditionalBuilderInput("55i8", conditional = optionalBuilderInputs),
"IntInput" to conditionalBuilderInput("100_000i32", conditional = optionalBuilderInputs),
"LongInput" to conditionalBuilderInput("9_000_000_000i64", conditional = optionalBuilderInputs),
"ShortInput" to conditionalBuilderInput("16_000i16", conditional = optionalBuilderInputs),
"StringInput" to conditionalBuilderInput("\"test\"", conditional = optionalBuilderInputs),
"TimestampInput" to conditionalBuilderInput("DateTime::from_secs(5)", conditional = optionalBuilderInputs),
)
}
if (codegenTarget == CodegenTarget.CLIENT) {
unitTest(
"error_metadata",
unitTest("message_with_header_and_payload") {
rustTemplate(
"""
let message = msg("event", "MessageWithHeaderAndPayload", "application/octet-stream", b"payload")
.add_header(Header::new("header", HeaderValue::String("header".into())));
let result = $generator::new().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder()
.header(#{HeaderInput})
.payload(#{PayloadInput})
.build()
),
expect_event(result.unwrap())
);
""",
"HeaderInput" to conditionalBuilderInput("\"header\"", conditional = optionalBuilderInputs),
"PayloadInput" to conditionalBuilderInput("Blob::new(&b\"payload\"[..])", conditional = optionalBuilderInputs),
)
}
unitTest("message_with_no_header_payload_traits") {
rustTemplate(
"""
let message = msg(
"event",
"MessageWithNoHeaderPayloadTraits",
"${testCase.responseContentType}",
br##"${testCase.validMessageWithNoHeaderPayloadTraits}"##
);
let result = $generator::new().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder()
.some_int(#{IntInput})
.some_string(#{StringInput})
.build()
),
expect_event(result.unwrap())
);
""",
"IntInput" to conditionalBuilderInput("5", conditional = optionalBuilderInputs),
"StringInput" to conditionalBuilderInput("\"hello\"", conditional = optionalBuilderInputs),
)
}
unitTest("some_error") {
rustTemplate(
"""
let message = msg(
"exception",
"UnmodeledError",
"SomeError",
"${testCase.responseContentType}",
br#"${testCase.validUnmodeledError}"#
br##"${testCase.validSomeError}"##
);
let result = ${format(generator)}().unmarshall(&message);
let result = $generator::new().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
match expect_error(result.unwrap()) {
TestStreamError::Unhandled(err) => {
let message = format!("{}", aws_smithy_types::error::display::DisplayErrorContext(&err));
let expected = "message: \"unmodeled error\"";
assert!(message.contains(expected), "Expected '{message}' to contain '{expected}'");
}
kind => panic!("expected error metadata, but got {:?}", kind),
TestStreamError::SomeError(err) => assert_eq!(Some("some error"), err.message()),
#{AllowUnreachablePatterns:W}
kind => panic!("expected SomeError, but got {:?}", kind),
}
""",
"AllowUnreachablePatterns" to writable { Attribute.AllowUnreachablePatterns.render(this) },
)
}
@ -261,10 +263,21 @@ internal object EventStreamUnmarshallTestCases {
"wrong-content-type",
br#"${testCase.validTestStruct}"#
);
let result = ${format(generator)}().unmarshall(&message);
let result = $generator::new().unmarshall(&message);
assert!(result.is_err(), "expected error, got: {:?}", result);
assert!(format!("{}", result.err().unwrap()).contains("expected :content-type to be"));
""",
)
}
}
internal fun conditionalBuilderInput(
@Language("Rust", prefix = "macro_rules! foo { () => {{\n", suffix = "\n}}}") contents: String,
conditional: Boolean,
vararg ctx: Pair<String, Any>,
): Writable =
writable {
conditionalBlock("Some(", ".into())", conditional = conditional) {
rustTemplate(contents, *ctx)
}
}

View File

@ -1,121 +0,0 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.server.smithy.protocols.eventstream
import org.junit.jupiter.api.extension.ExtensionContext
import org.junit.jupiter.params.provider.Arguments
import org.junit.jupiter.params.provider.ArgumentsProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.implBlock
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplGenerator
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestRequirements
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenConfig
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGeneratorWithoutPublicConstrainedTypes
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerOperationErrorGenerator
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestRustSettings
import java.util.stream.Stream
data class TestCase(
val eventStreamTestCase: EventStreamTestModels.TestCase,
val publicConstrainedTypes: Boolean,
) {
override fun toString(): String = "$eventStreamTestCase, publicConstrainedTypes = $publicConstrainedTypes"
}
class TestCasesProvider : ArgumentsProvider {
override fun provideArguments(context: ExtensionContext?): Stream<out Arguments> =
EventStreamTestModels.TEST_CASES
.flatMap { testCase ->
listOf(
TestCase(testCase, publicConstrainedTypes = false),
TestCase(testCase, publicConstrainedTypes = true),
)
}.map { Arguments.of(it) }.stream()
}
abstract class ServerEventStreamBaseRequirements : EventStreamTestRequirements<ServerCodegenContext> {
abstract val publicConstrainedTypes: Boolean
override fun createCodegenContext(
model: Model,
serviceShape: ServiceShape,
protocolShapeId: ShapeId,
codegenTarget: CodegenTarget,
): ServerCodegenContext = serverTestCodegenContext(
model,
serviceShape,
serverTestRustSettings(
codegenConfig = ServerCodegenConfig(publicConstrainedTypes = publicConstrainedTypes),
),
protocolShapeId,
)
override fun renderBuilderForShape(
rustCrate: RustCrate,
writer: RustWriter,
codegenContext: ServerCodegenContext,
shape: StructureShape,
) {
val validationExceptionConversionGenerator = SmithyValidationExceptionConversionGenerator(codegenContext)
if (codegenContext.settings.codegenConfig.publicConstrainedTypes) {
ServerBuilderGenerator(codegenContext, shape, validationExceptionConversionGenerator).apply {
render(rustCrate, writer)
writer.implBlock(codegenContext.symbolProvider.toSymbol(shape)) {
renderConvenienceMethod(writer)
}
}
} else {
ServerBuilderGeneratorWithoutPublicConstrainedTypes(codegenContext, shape, validationExceptionConversionGenerator).apply {
render(rustCrate, writer)
writer.implBlock(codegenContext.symbolProvider.toSymbol(shape)) {
renderConvenienceMethod(writer)
}
}
}
}
override fun renderOperationError(
writer: RustWriter,
model: Model,
symbolProvider: RustSymbolProvider,
operationOrEventStream: Shape,
) {
ServerOperationErrorGenerator(model, symbolProvider, operationOrEventStream).render(writer)
}
override fun renderError(
rustCrate: RustCrate,
writer: RustWriter,
codegenContext: ServerCodegenContext,
shape: StructureShape,
) {
StructureGenerator(codegenContext.model, codegenContext.symbolProvider, writer, shape, listOf()).render()
ErrorImplGenerator(
codegenContext.model,
codegenContext.symbolProvider,
writer,
shape,
shape.getTrait()!!,
listOf(),
).render(CodegenTarget.SERVER)
renderBuilderForShape(rustCrate, writer, codegenContext, shape)
}
}

View File

@ -5,49 +5,43 @@
package software.amazon.smithy.rust.codegen.server.smithy.protocols.eventstream
import org.junit.jupiter.api.extension.ExtensionContext
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.Arguments
import org.junit.jupiter.params.provider.ArgumentsProvider
import org.junit.jupiter.params.provider.ArgumentsSource
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.EventStreamMarshallerGenerator
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety
import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject
import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.renderInlineMemoryModules
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamMarshallTestCases.writeMarshallTestCases
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels
import software.amazon.smithy.rust.codegen.core.testutil.testModule
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest
import java.util.stream.Stream
class ServerEventStreamMarshallerGeneratorTest {
@ParameterizedTest
@ArgumentsSource(TestCasesProvider::class)
fun test(testCase: TestCase) {
val testProject = EventStreamTestTools.setupTestCase(
testCase.eventStreamTestCase,
object : ServerEventStreamBaseRequirements() {
override val publicConstrainedTypes: Boolean get() = testCase.publicConstrainedTypes
override fun renderGenerator(
codegenContext: ServerCodegenContext,
project: TestEventStreamProject,
protocol: Protocol,
): RuntimeType {
return EventStreamMarshallerGenerator(
project.model,
CodegenTarget.SERVER,
TestRuntimeConfig,
project.symbolProvider,
project.streamShape,
protocol.structuredDataSerializer(project.operationShape),
testCase.eventStreamTestCase.requestContentType,
).render()
}
},
CodegenTarget.SERVER,
EventStreamTestVariety.Marshall,
)
testProject.renderInlineMemoryModules()
testProject.compileAndTest()
serverIntegrationTest(testCase.eventStreamTestCase.model) { _, rustCrate ->
rustCrate.testModule {
writeMarshallTestCases(testCase.eventStreamTestCase, optionalBuilderInputs = true)
}
}
}
}
data class TestCase(
val eventStreamTestCase: EventStreamTestModels.TestCase,
val publicConstrainedTypes: Boolean,
) {
override fun toString(): String = "$eventStreamTestCase, publicConstrainedTypes = $publicConstrainedTypes"
}
class TestCasesProvider : ArgumentsProvider {
override fun provideArguments(context: ExtensionContext?): Stream<out Arguments> =
EventStreamTestModels.TEST_CASES
.flatMap { testCase ->
listOf(
TestCase(testCase, publicConstrainedTypes = false),
TestCase(testCase, publicConstrainedTypes = true),
)
}.map { Arguments.of(it) }.stream()
}

View File

@ -7,21 +7,10 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols.eventstream
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ArgumentsSource
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.implBlock
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.EventStreamUnmarshallerGenerator
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety
import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.transformers.ConstrainedMemberTransform
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.writeUnmarshallTestCases
import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams
import software.amazon.smithy.rust.codegen.core.testutil.testModule
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest
class ServerEventStreamUnmarshallerGeneratorTest {
@ParameterizedTest
@ -33,45 +22,16 @@ class ServerEventStreamUnmarshallerGeneratorTest {
return
}
val testProject = EventStreamTestTools.setupTestCase(
testCase.eventStreamTestCase,
object : ServerEventStreamBaseRequirements() {
override val publicConstrainedTypes: Boolean get() = testCase.publicConstrainedTypes
override fun renderGenerator(
codegenContext: ServerCodegenContext,
project: TestEventStreamProject,
protocol: Protocol,
): RuntimeType {
return EventStreamUnmarshallerGenerator(
protocol,
codegenContext,
project.operationShape,
project.streamShape,
).render()
}
// TODO(https://github.com/awslabs/smithy-rs/issues/1442): Delete this function override to use the correct builder from the parent class
override fun renderBuilderForShape(
rustCrate: RustCrate,
writer: RustWriter,
codegenContext: ServerCodegenContext,
shape: StructureShape,
) {
rustCrate.withModule(codegenContext.symbolProvider.moduleForBuilder(shape)) {
BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape, emptyList()).render(this)
}
rustCrate.moduleFor(shape) {
writer.implBlock(codegenContext.symbolProvider.toSymbol(shape)) {
BuilderGenerator.renderConvenienceMethod(this, codegenContext.symbolProvider, shape)
}
}
}
},
CodegenTarget.SERVER,
EventStreamTestVariety.Unmarshall,
transformers = listOf(ConstrainedMemberTransform::transform),
)
testProject.compileAndTest()
serverIntegrationTest(
testCase.eventStreamTestCase.model,
IntegrationTestParams(service = "test#TestService", addModuleToEventStreamAllowList = true),
) { _, rustCrate ->
rustCrate.testModule {
writeUnmarshallTestCases(
testCase.eventStreamTestCase,
optionalBuilderInputs = true,
)
}
}
}
}

View File

@ -14,6 +14,7 @@ pub fn body_is_error(body: &[u8]) -> Result<bool, XmlDecodeError> {
Ok(scoped.start_el().matches("ErrorResponse"))
}
#[allow(dead_code)]
pub fn parse_error_metadata(body: &[u8]) -> Result<ErrorMetadataBuilder, XmlDecodeError> {
let mut doc = Document::try_from(body)?;
let mut root = doc.root_element()?;