mirror of https://github.com/smithy-lang/smithy-rs
Refactor event stream tests with `{client,server}IntegrationTest`s (#2342)
* Refactor `ClientEventStreamUnmarshallerGeneratorTest` to use `clientIntegrationTest` (WIP) * Refactor `ClientEventStreamUnmarshallerGeneratorTest` with `clientIntegrationTest` * Refactor `ClientEventStreamUnmarshallerGeneratorTest` to use generic test cases * Start refactoring `ServerEventStreamUnmarshallerGeneratorTest` * Make `ServerEventStreamUnmarshallerGeneratorTest` tests work * Uncomment other test models * Allow unused on `parse_generic_error` * Rename `ServerEventStreamUnmarshallerGeneratorTest` * Make `EventStreamUnmarshallTestCases` codegenTarget-agnostic * Refactor `ClientEventStreamMarshallerGeneratorTest`: Tests run but fail * Refactor `ServerEventStreamMarshallerGeneratorTest` * Move `.into()` calls to `conditionalBuilderInput` * Add "context" to TODO * Fix client unmarshall tests * Fix clippy lint * Fix more clippy lints * Add docs for `event_stream_serde` module * Fix client tests * Remove `#[allow(missing_docs)]` from event stream module * Remove unused `EventStreamTestTools` * Add `smithy-validation-model` test dep to `codegen-client` * Temporarily add docs to make tests compile * Undo change in model * Make event stream unmarshaller tests a unit test * Remove unused code * Make `ServerEventStreamUnmarshallerGeneratorTest` a unit test * Make `ServerEventStreamMarshallerGeneratorTest` a unit test * Make `ServerEventStreamMarshallerGeneratorTest` pass * Make remaining tests non-integration tests * Make event stream serde module private again * Remove unnecessary clippy allowances * Remove clippy allowance * Remove docs for `event_stream_serde` module * Remove docs for `$unmarshallerTypeName::new` * Remove more unnecessary docs * Remove more superfluous docs * Undo unnecessary diffs * Uncomment last test * Make `conditionalBuilderInput` internal
This commit is contained in:
parent
72df8440c0
commit
c3ae6f7eaf
|
@ -28,6 +28,10 @@ dependencies {
|
|||
implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion")
|
||||
implementation("software.amazon.smithy:smithy-waiters:$smithyVersion")
|
||||
implementation("software.amazon.smithy:smithy-rules-engine:$smithyVersion")
|
||||
|
||||
// `smithy.framework#ValidationException` is defined here, which is used in event stream
|
||||
// marshalling/unmarshalling tests.
|
||||
testImplementation("software.amazon.smithy:smithy-validation-model:$smithyVersion")
|
||||
}
|
||||
|
||||
tasks.compileKotlin {
|
||||
|
|
|
@ -1,98 +0,0 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package software.amazon.smithy.rust.codegen.client.smithy.protocols.eventstream
|
||||
|
||||
import org.junit.jupiter.api.extension.ExtensionContext
|
||||
import org.junit.jupiter.params.provider.Arguments
|
||||
import org.junit.jupiter.params.provider.ArgumentsProvider
|
||||
import software.amazon.smithy.model.Model
|
||||
import software.amazon.smithy.model.shapes.ServiceShape
|
||||
import software.amazon.smithy.model.shapes.Shape
|
||||
import software.amazon.smithy.model.shapes.ShapeId
|
||||
import software.amazon.smithy.model.shapes.StructureShape
|
||||
import software.amazon.smithy.model.traits.ErrorTrait
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.customize.CombinedClientCodegenDecorator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ErrorGenerator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.error.OperationErrorGenerator
|
||||
import software.amazon.smithy.rust.codegen.client.testutil.testClientRustSettings
|
||||
import software.amazon.smithy.rust.codegen.client.testutil.testSymbolProvider
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.implBlock
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestRequirements
|
||||
import software.amazon.smithy.rust.codegen.core.util.expectTrait
|
||||
import java.util.stream.Stream
|
||||
|
||||
class TestCasesProvider : ArgumentsProvider {
|
||||
override fun provideArguments(context: ExtensionContext?): Stream<out Arguments> =
|
||||
EventStreamTestModels.TEST_CASES.map { Arguments.of(it) }.stream()
|
||||
}
|
||||
|
||||
abstract class ClientEventStreamBaseRequirements : EventStreamTestRequirements<ClientCodegenContext> {
|
||||
override fun createCodegenContext(
|
||||
model: Model,
|
||||
serviceShape: ServiceShape,
|
||||
protocolShapeId: ShapeId,
|
||||
codegenTarget: CodegenTarget,
|
||||
): ClientCodegenContext = ClientCodegenContext(
|
||||
model,
|
||||
testSymbolProvider(model),
|
||||
serviceShape,
|
||||
protocolShapeId,
|
||||
testClientRustSettings(),
|
||||
CombinedClientCodegenDecorator(emptyList()),
|
||||
)
|
||||
|
||||
override fun renderBuilderForShape(
|
||||
rustCrate: RustCrate,
|
||||
writer: RustWriter,
|
||||
codegenContext: ClientCodegenContext,
|
||||
shape: StructureShape,
|
||||
) {
|
||||
BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape, emptyList()).apply {
|
||||
render(writer)
|
||||
}
|
||||
writer.implBlock(codegenContext.symbolProvider.toSymbol(shape)) {
|
||||
BuilderGenerator.renderConvenienceMethod(writer, codegenContext.symbolProvider, shape)
|
||||
}
|
||||
}
|
||||
|
||||
override fun renderOperationError(
|
||||
writer: RustWriter,
|
||||
model: Model,
|
||||
symbolProvider: RustSymbolProvider,
|
||||
operationOrEventStream: Shape,
|
||||
) {
|
||||
OperationErrorGenerator(model, symbolProvider, operationOrEventStream, emptyList()).render(writer)
|
||||
}
|
||||
|
||||
override fun renderError(
|
||||
rustCrate: RustCrate,
|
||||
writer: RustWriter,
|
||||
codegenContext: ClientCodegenContext,
|
||||
shape: StructureShape,
|
||||
) {
|
||||
val errorTrait = shape.expectTrait<ErrorTrait>()
|
||||
val errorGenerator = ErrorGenerator(
|
||||
codegenContext.model,
|
||||
codegenContext.symbolProvider,
|
||||
shape,
|
||||
errorTrait,
|
||||
emptyList(),
|
||||
)
|
||||
rustCrate.useShapeWriter(shape) {
|
||||
errorGenerator.renderStruct(this)
|
||||
}
|
||||
rustCrate.withModule(codegenContext.symbolProvider.moduleForBuilder(shape)) {
|
||||
errorGenerator.renderBuilder(this)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -5,43 +5,30 @@
|
|||
|
||||
package software.amazon.smithy.rust.codegen.client.smithy.protocols.eventstream
|
||||
|
||||
import org.junit.jupiter.api.extension.ExtensionContext
|
||||
import org.junit.jupiter.params.ParameterizedTest
|
||||
import org.junit.jupiter.params.provider.Arguments
|
||||
import org.junit.jupiter.params.provider.ArgumentsProvider
|
||||
import org.junit.jupiter.params.provider.ArgumentsSource
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.EventStreamMarshallerGenerator
|
||||
import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamMarshallTestCases.writeMarshallTestCases
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.testModule
|
||||
import java.util.stream.Stream
|
||||
|
||||
class ClientEventStreamMarshallerGeneratorTest {
|
||||
@ParameterizedTest
|
||||
@ArgumentsSource(TestCasesProvider::class)
|
||||
fun test(testCase: EventStreamTestModels.TestCase) {
|
||||
EventStreamTestTools.setupTestCase(
|
||||
testCase,
|
||||
object : ClientEventStreamBaseRequirements() {
|
||||
override fun renderGenerator(
|
||||
codegenContext: ClientCodegenContext,
|
||||
project: TestEventStreamProject,
|
||||
protocol: Protocol,
|
||||
): RuntimeType = EventStreamMarshallerGenerator(
|
||||
project.model,
|
||||
CodegenTarget.CLIENT,
|
||||
TestRuntimeConfig,
|
||||
project.symbolProvider,
|
||||
project.streamShape,
|
||||
protocol.structuredDataSerializer(project.operationShape),
|
||||
testCase.requestContentType,
|
||||
).render()
|
||||
},
|
||||
CodegenTarget.CLIENT,
|
||||
EventStreamTestVariety.Marshall,
|
||||
).compileAndTest()
|
||||
clientIntegrationTest(testCase.model) { _, rustCrate ->
|
||||
rustCrate.testModule {
|
||||
writeMarshallTestCases(testCase, optionalBuilderInputs = false)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class TestCasesProvider : ArgumentsProvider {
|
||||
override fun provideArguments(context: ExtensionContext?): Stream<out Arguments> =
|
||||
EventStreamTestModels.TEST_CASES.map { Arguments.of(it) }.stream()
|
||||
}
|
||||
|
|
|
@ -7,39 +7,60 @@ package software.amazon.smithy.rust.codegen.client.smithy.protocols.eventstream
|
|||
|
||||
import org.junit.jupiter.params.ParameterizedTest
|
||||
import org.junit.jupiter.params.provider.ArgumentsSource
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.EventStreamUnmarshallerGenerator
|
||||
import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rust
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.writeUnmarshallTestCases
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.testModule
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.unitTest
|
||||
|
||||
class ClientEventStreamUnmarshallerGeneratorTest {
|
||||
@ParameterizedTest
|
||||
@ArgumentsSource(TestCasesProvider::class)
|
||||
fun test(testCase: EventStreamTestModels.TestCase) {
|
||||
EventStreamTestTools.setupTestCase(
|
||||
testCase,
|
||||
object : ClientEventStreamBaseRequirements() {
|
||||
override fun renderGenerator(
|
||||
codegenContext: ClientCodegenContext,
|
||||
project: TestEventStreamProject,
|
||||
protocol: Protocol,
|
||||
): RuntimeType {
|
||||
return EventStreamUnmarshallerGenerator(
|
||||
protocol,
|
||||
codegenContext,
|
||||
project.operationShape,
|
||||
project.streamShape,
|
||||
).render()
|
||||
}
|
||||
},
|
||||
CodegenTarget.CLIENT,
|
||||
EventStreamTestVariety.Unmarshall,
|
||||
).compileAndTest()
|
||||
clientIntegrationTest(
|
||||
testCase.model,
|
||||
IntegrationTestParams(service = "test#TestService", addModuleToEventStreamAllowList = true),
|
||||
) { _, rustCrate ->
|
||||
val generator = "crate::event_stream_serde::TestStreamUnmarshaller"
|
||||
|
||||
rustCrate.testModule {
|
||||
rust("##![allow(unused_imports, dead_code)]")
|
||||
writeUnmarshallTestCases(testCase, optionalBuilderInputs = false)
|
||||
|
||||
unitTest(
|
||||
"unknown_message",
|
||||
"""
|
||||
let message = msg("event", "NewUnmodeledMessageType", "application/octet-stream", b"hello, world!");
|
||||
let result = $generator::new().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
assert!(expect_event(result.unwrap()).is_unknown());
|
||||
""",
|
||||
)
|
||||
|
||||
unitTest(
|
||||
"generic_error",
|
||||
"""
|
||||
let message = msg(
|
||||
"exception",
|
||||
"UnmodeledError",
|
||||
"${testCase.responseContentType}",
|
||||
br#"${testCase.validUnmodeledError}"#
|
||||
);
|
||||
let result = $generator::new().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
match expect_error(result.unwrap()) {
|
||||
TestStreamError::Unhandled(err) => {
|
||||
let message = format!("{}", aws_smithy_types::error::display::DisplayErrorContext(&err));
|
||||
let expected = "message: \"unmodeled error\"";
|
||||
assert!(message.contains(expected), "Expected '{message}' to contain '{expected}'");
|
||||
}
|
||||
kind => panic!("expected generic error, but got {:?}", kind),
|
||||
}
|
||||
""",
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -470,9 +470,11 @@ class Attribute(val inner: Writable) {
|
|||
val AllowDeprecated = Attribute(allow("deprecated"))
|
||||
val AllowIrrefutableLetPatterns = Attribute(allow("irrefutable_let_patterns"))
|
||||
val AllowUnreachableCode = Attribute(allow("unreachable_code"))
|
||||
val AllowUnreachablePatterns = Attribute(allow("unreachable_patterns"))
|
||||
val AllowUnusedImports = Attribute(allow("unused_imports"))
|
||||
val AllowUnusedMut = Attribute(allow("unused_mut"))
|
||||
val AllowUnusedVariables = Attribute(allow("unused_variables"))
|
||||
val AllowMissingDocs = Attribute(allow("missing_docs"))
|
||||
val CfgTest = Attribute(cfg("test"))
|
||||
val DenyMissingDocs = Attribute(deny("missing_docs"))
|
||||
val DocHidden = Attribute(doc("hidden"))
|
||||
|
|
|
@ -116,7 +116,7 @@ private fun <T : AbstractCodeWriter<T>, U> T.withTemplate(
|
|||
* This enables conditionally wrapping a block in a prefix/suffix, e.g.
|
||||
*
|
||||
* ```
|
||||
* writer.withBlock("Some(", ")", conditional = symbol.isOptional()) {
|
||||
* writer.conditionalBlock("Some(", ")", conditional = symbol.isOptional()) {
|
||||
* write("symbolValue")
|
||||
* }
|
||||
* ```
|
||||
|
|
|
@ -43,6 +43,9 @@ import software.amazon.smithy.rust.codegen.core.util.expectTrait
|
|||
import software.amazon.smithy.rust.codegen.core.util.hasTrait
|
||||
import software.amazon.smithy.rust.codegen.core.util.toPascalCase
|
||||
|
||||
fun RustModule.Companion.eventStreamSerdeModule(): RustModule.LeafModule =
|
||||
private("event_stream_serde")
|
||||
|
||||
class EventStreamUnmarshallerGenerator(
|
||||
private val protocol: Protocol,
|
||||
codegenContext: CodegenContext,
|
||||
|
@ -60,7 +63,7 @@ class EventStreamUnmarshallerGenerator(
|
|||
symbolProvider.symbolForEventStreamError(unionShape)
|
||||
}
|
||||
private val smithyEventStream = RuntimeType.smithyEventStream(runtimeConfig)
|
||||
private val eventStreamSerdeModule = RustModule.private("event_stream_serde")
|
||||
private val eventStreamSerdeModule = RustModule.eventStreamSerdeModule()
|
||||
private val codegenScope = arrayOf(
|
||||
"Blob" to RuntimeType.blob(runtimeConfig),
|
||||
"expect_fns" to smithyEventStream.resolve("smithy"),
|
||||
|
@ -84,15 +87,16 @@ class EventStreamUnmarshallerGenerator(
|
|||
}
|
||||
|
||||
private fun RustWriter.renderUnmarshaller(unmarshallerType: RuntimeType, unionSymbol: Symbol) {
|
||||
val unmarshallerTypeName = unmarshallerType.name
|
||||
rust(
|
||||
"""
|
||||
##[non_exhaustive]
|
||||
##[derive(Debug)]
|
||||
pub struct ${unmarshallerType.name};
|
||||
pub struct $unmarshallerTypeName;
|
||||
|
||||
impl ${unmarshallerType.name} {
|
||||
impl $unmarshallerTypeName {
|
||||
pub fn new() -> Self {
|
||||
${unmarshallerType.name}
|
||||
$unmarshallerTypeName
|
||||
}
|
||||
}
|
||||
""",
|
||||
|
@ -154,6 +158,7 @@ class EventStreamUnmarshallerGenerator(
|
|||
"Output" to unionSymbol,
|
||||
*codegenScope,
|
||||
)
|
||||
|
||||
false -> rustTemplate(
|
||||
"return Err(#{Error}::unmarshalling(format!(\"unrecognized :event-type: {}\", _unknown_variant)));",
|
||||
*codegenScope,
|
||||
|
@ -179,6 +184,7 @@ class EventStreamUnmarshallerGenerator(
|
|||
*codegenScope,
|
||||
)
|
||||
}
|
||||
|
||||
payloadOnly -> {
|
||||
withBlock("let parsed = ", ";") {
|
||||
renderParseProtocolPayload(unionMember)
|
||||
|
@ -189,6 +195,7 @@ class EventStreamUnmarshallerGenerator(
|
|||
*codegenScope,
|
||||
)
|
||||
}
|
||||
|
||||
else -> {
|
||||
rust("let mut builder = #T::default();", symbolProvider.symbolForBuilder(unionStruct))
|
||||
val payloadMember = unionStruct.members().firstOrNull { it.hasTrait<EventPayloadTrait>() }
|
||||
|
@ -265,6 +272,7 @@ class EventStreamUnmarshallerGenerator(
|
|||
is BlobShape -> {
|
||||
rustTemplate("#{Blob}::new(message.payload().as_ref())", *codegenScope)
|
||||
}
|
||||
|
||||
is StringShape -> {
|
||||
rustTemplate(
|
||||
"""
|
||||
|
@ -275,6 +283,7 @@ class EventStreamUnmarshallerGenerator(
|
|||
*codegenScope,
|
||||
)
|
||||
}
|
||||
|
||||
is UnionShape, is StructureShape -> {
|
||||
renderParseProtocolPayload(member)
|
||||
}
|
||||
|
@ -312,6 +321,7 @@ class EventStreamUnmarshallerGenerator(
|
|||
*codegenScope,
|
||||
)
|
||||
}
|
||||
|
||||
CodegenTarget.SERVER -> {}
|
||||
}
|
||||
|
||||
|
@ -350,6 +360,7 @@ class EventStreamUnmarshallerGenerator(
|
|||
)
|
||||
}
|
||||
}
|
||||
|
||||
CodegenTarget.SERVER -> {
|
||||
val target = model.expectShape(member.target, StructureShape::class.java)
|
||||
val parser = protocol.structuredDataParser(operationShape).errorParser(target)
|
||||
|
@ -391,6 +402,7 @@ class EventStreamUnmarshallerGenerator(
|
|||
CodegenTarget.CLIENT -> {
|
||||
rustTemplate("Ok(#{UnmarshalledMessage}::Error(#{OpError}::generic(generic)))", *codegenScope)
|
||||
}
|
||||
|
||||
CodegenTarget.SERVER -> {
|
||||
rustTemplate(
|
||||
"""
|
||||
|
|
|
@ -25,6 +25,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
|
|||
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.unknownVariantError
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.eventStreamSerdeModule
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.rustType
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticEventStreamUnionTrait
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.transformers.eventStreamErrors
|
||||
|
@ -49,7 +50,7 @@ class EventStreamErrorMarshallerGenerator(
|
|||
} else {
|
||||
symbolProvider.symbolForEventStreamError(unionShape)
|
||||
}
|
||||
private val eventStreamSerdeModule = RustModule.private("event_stream_serde")
|
||||
private val eventStreamSerdeModule = RustModule.eventStreamSerdeModule()
|
||||
private val errorsShape = unionShape.expectTrait<SyntheticEventStreamUnionTrait>()
|
||||
private val codegenScope = arrayOf(
|
||||
"MarshallMessage" to smithyEventStream.resolve("frame::MarshallMessage"),
|
||||
|
@ -126,7 +127,7 @@ class EventStreamErrorMarshallerGenerator(
|
|||
}
|
||||
}
|
||||
|
||||
fun RustWriter.renderMarshallEvent(unionMember: MemberShape, eventStruct: StructureShape) {
|
||||
private fun RustWriter.renderMarshallEvent(unionMember: MemberShape, eventStruct: StructureShape) {
|
||||
val headerMembers = eventStruct.members().filter { it.hasTrait<EventHeaderTrait>() }
|
||||
val payloadMember = eventStruct.members().firstOrNull { it.hasTrait<EventPayloadTrait>() }
|
||||
for (member in headerMembers) {
|
||||
|
|
|
@ -38,6 +38,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
|
|||
import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.unknownVariantError
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.eventStreamSerdeModule
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.rustType
|
||||
import software.amazon.smithy.rust.codegen.core.util.dq
|
||||
import software.amazon.smithy.rust.codegen.core.util.hasTrait
|
||||
|
@ -53,7 +54,7 @@ open class EventStreamMarshallerGenerator(
|
|||
private val payloadContentType: String,
|
||||
) {
|
||||
private val smithyEventStream = RuntimeType.smithyEventStream(runtimeConfig)
|
||||
private val eventStreamSerdeModule = RustModule.private("event_stream_serde")
|
||||
private val eventStreamSerdeModule = RustModule.eventStreamSerdeModule()
|
||||
private val codegenScope = arrayOf(
|
||||
"MarshallMessage" to smithyEventStream.resolve("frame::MarshallMessage"),
|
||||
"Message" to smithyEventStream.resolve("frame::Message"),
|
||||
|
|
|
@ -5,26 +5,35 @@
|
|||
|
||||
package software.amazon.smithy.rust.codegen.core.testutil
|
||||
|
||||
import org.intellij.lang.annotations.Language
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
|
||||
import software.amazon.smithy.rust.codegen.core.util.dq
|
||||
|
||||
internal object EventStreamMarshallTestCases {
|
||||
internal fun RustWriter.writeMarshallTestCases(
|
||||
object EventStreamMarshallTestCases {
|
||||
fun RustWriter.writeMarshallTestCases(
|
||||
testCase: EventStreamTestModels.TestCase,
|
||||
generator: RuntimeType,
|
||||
optionalBuilderInputs: Boolean,
|
||||
) {
|
||||
val generator = "crate::event_stream_serde::TestStreamMarshaller"
|
||||
|
||||
val protocolTestHelpers = CargoDependency.smithyProtocolTestHelpers(TestRuntimeConfig)
|
||||
.copy(scope = DependencyScope.Compile)
|
||||
|
||||
fun builderInput(
|
||||
@Language("Rust", prefix = "macro_rules! foo { () => {{\n", suffix = "\n}}}")
|
||||
input: String,
|
||||
vararg ctx: Pair<String, Any>,
|
||||
): Writable = conditionalBuilderInput(input, conditional = optionalBuilderInputs, ctx = ctx)
|
||||
|
||||
rustTemplate(
|
||||
"""
|
||||
use aws_smithy_eventstream::frame::{Message, Header, HeaderValue, MarshallMessage};
|
||||
use std::collections::HashMap;
|
||||
use aws_smithy_types::{Blob, DateTime};
|
||||
use crate::error::*;
|
||||
use crate::model::*;
|
||||
|
||||
use #{validate_body};
|
||||
|
@ -46,163 +55,192 @@ internal object EventStreamMarshallTestCases {
|
|||
"MediaType" to protocolTestHelpers.toType().resolve("MediaType"),
|
||||
)
|
||||
|
||||
unitTest(
|
||||
"message_with_blob",
|
||||
"""
|
||||
let event = TestStream::MessageWithBlob(
|
||||
MessageWithBlob::builder().data(Blob::new(&b"hello, world!"[..])).build()
|
||||
);
|
||||
let result = ${format(generator)}().marshall(event);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
let message = result.unwrap();
|
||||
let headers = headers_to_map(message.headers());
|
||||
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
|
||||
assert_eq!(&str_header("MessageWithBlob"), *headers.get(":event-type").unwrap());
|
||||
assert_eq!(&str_header("application/octet-stream"), *headers.get(":content-type").unwrap());
|
||||
assert_eq!(&b"hello, world!"[..], message.payload());
|
||||
""",
|
||||
)
|
||||
unitTest("message_with_blob") {
|
||||
rustTemplate(
|
||||
"""
|
||||
let event = TestStream::MessageWithBlob(
|
||||
MessageWithBlob::builder().data(#{BlobInput:W}).build()
|
||||
);
|
||||
let result = $generator::new().marshall(event);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
let message = result.unwrap();
|
||||
let headers = headers_to_map(message.headers());
|
||||
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
|
||||
assert_eq!(&str_header("MessageWithBlob"), *headers.get(":event-type").unwrap());
|
||||
assert_eq!(&str_header("application/octet-stream"), *headers.get(":content-type").unwrap());
|
||||
assert_eq!(&b"hello, world!"[..], message.payload());
|
||||
""",
|
||||
"BlobInput" to builderInput("Blob::new(&b\"hello, world!\"[..])"),
|
||||
)
|
||||
}
|
||||
|
||||
unitTest(
|
||||
"message_with_string",
|
||||
"""
|
||||
let event = TestStream::MessageWithString(
|
||||
MessageWithString::builder().data("hello, world!").build()
|
||||
);
|
||||
let result = ${format(generator)}().marshall(event);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
let message = result.unwrap();
|
||||
let headers = headers_to_map(message.headers());
|
||||
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
|
||||
assert_eq!(&str_header("MessageWithString"), *headers.get(":event-type").unwrap());
|
||||
assert_eq!(&str_header("text/plain"), *headers.get(":content-type").unwrap());
|
||||
assert_eq!(&b"hello, world!"[..], message.payload());
|
||||
""",
|
||||
)
|
||||
unitTest("message_with_string") {
|
||||
rustTemplate(
|
||||
"""
|
||||
let event = TestStream::MessageWithString(
|
||||
MessageWithString::builder().data(#{StringInput}).build()
|
||||
);
|
||||
let result = $generator::new().marshall(event);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
let message = result.unwrap();
|
||||
let headers = headers_to_map(message.headers());
|
||||
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
|
||||
assert_eq!(&str_header("MessageWithString"), *headers.get(":event-type").unwrap());
|
||||
assert_eq!(&str_header("text/plain"), *headers.get(":content-type").unwrap());
|
||||
assert_eq!(&b"hello, world!"[..], message.payload());
|
||||
""",
|
||||
"StringInput" to builderInput("\"hello, world!\""),
|
||||
)
|
||||
}
|
||||
|
||||
unitTest(
|
||||
"message_with_struct",
|
||||
"""
|
||||
let event = TestStream::MessageWithStruct(
|
||||
MessageWithStruct::builder().some_struct(
|
||||
TestStruct::builder()
|
||||
.some_string("hello")
|
||||
.some_int(5)
|
||||
.build()
|
||||
).build()
|
||||
);
|
||||
let result = ${format(generator)}().marshall(event);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
let message = result.unwrap();
|
||||
let headers = headers_to_map(message.headers());
|
||||
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
|
||||
assert_eq!(&str_header("MessageWithStruct"), *headers.get(":event-type").unwrap());
|
||||
assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap());
|
||||
unitTest("message_with_struct") {
|
||||
rustTemplate(
|
||||
"""
|
||||
let event = TestStream::MessageWithStruct(
|
||||
MessageWithStruct::builder().some_struct(#{StructInput}).build()
|
||||
);
|
||||
let result = $generator::new().marshall(event);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
let message = result.unwrap();
|
||||
let headers = headers_to_map(message.headers());
|
||||
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
|
||||
assert_eq!(&str_header("MessageWithStruct"), *headers.get(":event-type").unwrap());
|
||||
assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap());
|
||||
|
||||
validate_body(
|
||||
message.payload(),
|
||||
${testCase.validTestStruct.dq()},
|
||||
MediaType::from(${testCase.requestContentType.dq()})
|
||||
).unwrap();
|
||||
""",
|
||||
)
|
||||
validate_body(
|
||||
message.payload(),
|
||||
${testCase.validTestStruct.dq()},
|
||||
MediaType::from(${testCase.mediaType.dq()})
|
||||
).unwrap();
|
||||
""",
|
||||
"StructInput" to
|
||||
builderInput(
|
||||
"""
|
||||
TestStruct::builder()
|
||||
.some_string(#{StringInput})
|
||||
.some_int(#{IntInput})
|
||||
.build()
|
||||
""",
|
||||
"IntInput" to builderInput("5"),
|
||||
"StringInput" to builderInput("\"hello\""),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
unitTest(
|
||||
"message_with_union",
|
||||
"""
|
||||
let event = TestStream::MessageWithUnion(MessageWithUnion::builder().some_union(
|
||||
TestUnion::Foo("hello".into())
|
||||
).build());
|
||||
let result = ${format(generator)}().marshall(event);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
let message = result.unwrap();
|
||||
let headers = headers_to_map(message.headers());
|
||||
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
|
||||
assert_eq!(&str_header("MessageWithUnion"), *headers.get(":event-type").unwrap());
|
||||
assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap());
|
||||
unitTest("message_with_union") {
|
||||
rustTemplate(
|
||||
"""
|
||||
let event = TestStream::MessageWithUnion(MessageWithUnion::builder()
|
||||
.some_union(#{UnionInput})
|
||||
.build()
|
||||
);
|
||||
let result = $generator::new().marshall(event);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
let message = result.unwrap();
|
||||
let headers = headers_to_map(message.headers());
|
||||
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
|
||||
assert_eq!(&str_header("MessageWithUnion"), *headers.get(":event-type").unwrap());
|
||||
assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap());
|
||||
|
||||
validate_body(
|
||||
message.payload(),
|
||||
${testCase.validTestUnion.dq()},
|
||||
MediaType::from(${testCase.requestContentType.dq()})
|
||||
).unwrap();
|
||||
""",
|
||||
)
|
||||
validate_body(
|
||||
message.payload(),
|
||||
${testCase.validTestUnion.dq()},
|
||||
MediaType::from(${testCase.mediaType.dq()})
|
||||
).unwrap();
|
||||
""",
|
||||
"UnionInput" to builderInput("TestUnion::Foo(\"hello\".into())"),
|
||||
)
|
||||
}
|
||||
|
||||
unitTest(
|
||||
"message_with_headers",
|
||||
"""
|
||||
let event = TestStream::MessageWithHeaders(MessageWithHeaders::builder()
|
||||
.blob(Blob::new(&b"test"[..]))
|
||||
.boolean(true)
|
||||
.byte(55i8)
|
||||
.int(100_000i32)
|
||||
.long(9_000_000_000i64)
|
||||
.short(16_000i16)
|
||||
.string("test")
|
||||
.timestamp(DateTime::from_secs(5))
|
||||
.build()
|
||||
);
|
||||
let result = ${format(generator)}().marshall(event);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
let actual_message = result.unwrap();
|
||||
let expected_message = Message::new(&b""[..])
|
||||
.add_header(Header::new(":message-type", HeaderValue::String("event".into())))
|
||||
.add_header(Header::new(":event-type", HeaderValue::String("MessageWithHeaders".into())))
|
||||
.add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into())))
|
||||
.add_header(Header::new("boolean", HeaderValue::Bool(true)))
|
||||
.add_header(Header::new("byte", HeaderValue::Byte(55i8)))
|
||||
.add_header(Header::new("int", HeaderValue::Int32(100_000i32)))
|
||||
.add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64)))
|
||||
.add_header(Header::new("short", HeaderValue::Int16(16_000i16)))
|
||||
.add_header(Header::new("string", HeaderValue::String("test".into())))
|
||||
.add_header(Header::new("timestamp", HeaderValue::Timestamp(DateTime::from_secs(5))));
|
||||
assert_eq!(expected_message, actual_message);
|
||||
""",
|
||||
)
|
||||
unitTest("message_with_headers") {
|
||||
rustTemplate(
|
||||
"""
|
||||
let event = TestStream::MessageWithHeaders(MessageWithHeaders::builder()
|
||||
.blob(#{BlobInput})
|
||||
.boolean(#{BooleanInput})
|
||||
.byte(#{ByteInput})
|
||||
.int(#{IntInput})
|
||||
.long(#{LongInput})
|
||||
.short(#{ShortInput})
|
||||
.string(#{StringInput})
|
||||
.timestamp(#{TimestampInput})
|
||||
.build()
|
||||
);
|
||||
let result = $generator::new().marshall(event);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
let actual_message = result.unwrap();
|
||||
let expected_message = Message::new(&b""[..])
|
||||
.add_header(Header::new(":message-type", HeaderValue::String("event".into())))
|
||||
.add_header(Header::new(":event-type", HeaderValue::String("MessageWithHeaders".into())))
|
||||
.add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into())))
|
||||
.add_header(Header::new("boolean", HeaderValue::Bool(true)))
|
||||
.add_header(Header::new("byte", HeaderValue::Byte(55i8)))
|
||||
.add_header(Header::new("int", HeaderValue::Int32(100_000i32)))
|
||||
.add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64)))
|
||||
.add_header(Header::new("short", HeaderValue::Int16(16_000i16)))
|
||||
.add_header(Header::new("string", HeaderValue::String("test".into())))
|
||||
.add_header(Header::new("timestamp", HeaderValue::Timestamp(DateTime::from_secs(5))));
|
||||
assert_eq!(expected_message, actual_message);
|
||||
""",
|
||||
"BlobInput" to builderInput("Blob::new(&b\"test\"[..])"),
|
||||
"BooleanInput" to builderInput("true"),
|
||||
"ByteInput" to builderInput("55i8"),
|
||||
"IntInput" to builderInput("100_000i32"),
|
||||
"LongInput" to builderInput("9_000_000_000i64"),
|
||||
"ShortInput" to builderInput("16_000i16"),
|
||||
"StringInput" to builderInput("\"test\""),
|
||||
"TimestampInput" to builderInput("DateTime::from_secs(5)"),
|
||||
)
|
||||
}
|
||||
|
||||
unitTest(
|
||||
"message_with_header_and_payload",
|
||||
"""
|
||||
let event = TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder()
|
||||
.header("header")
|
||||
.payload(Blob::new(&b"payload"[..]))
|
||||
.build()
|
||||
);
|
||||
let result = ${format(generator)}().marshall(event);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
let actual_message = result.unwrap();
|
||||
let expected_message = Message::new(&b"payload"[..])
|
||||
.add_header(Header::new(":message-type", HeaderValue::String("event".into())))
|
||||
.add_header(Header::new(":event-type", HeaderValue::String("MessageWithHeaderAndPayload".into())))
|
||||
.add_header(Header::new("header", HeaderValue::String("header".into())))
|
||||
.add_header(Header::new(":content-type", HeaderValue::String("application/octet-stream".into())));
|
||||
assert_eq!(expected_message, actual_message);
|
||||
""",
|
||||
)
|
||||
unitTest("message_with_header_and_payload") {
|
||||
rustTemplate(
|
||||
"""
|
||||
let event = TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder()
|
||||
.header(#{HeaderInput})
|
||||
.payload(#{PayloadInput})
|
||||
.build()
|
||||
);
|
||||
let result = $generator::new().marshall(event);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
let actual_message = result.unwrap();
|
||||
let expected_message = Message::new(&b"payload"[..])
|
||||
.add_header(Header::new(":message-type", HeaderValue::String("event".into())))
|
||||
.add_header(Header::new(":event-type", HeaderValue::String("MessageWithHeaderAndPayload".into())))
|
||||
.add_header(Header::new("header", HeaderValue::String("header".into())))
|
||||
.add_header(Header::new(":content-type", HeaderValue::String("application/octet-stream".into())));
|
||||
assert_eq!(expected_message, actual_message);
|
||||
""",
|
||||
"HeaderInput" to builderInput("\"header\""),
|
||||
"PayloadInput" to builderInput("Blob::new(&b\"payload\"[..])"),
|
||||
)
|
||||
}
|
||||
|
||||
unitTest(
|
||||
"message_with_no_header_payload_traits",
|
||||
"""
|
||||
let event = TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder()
|
||||
.some_int(5)
|
||||
.some_string("hello")
|
||||
.build()
|
||||
);
|
||||
let result = ${format(generator)}().marshall(event);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
let message = result.unwrap();
|
||||
let headers = headers_to_map(message.headers());
|
||||
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
|
||||
assert_eq!(&str_header("MessageWithNoHeaderPayloadTraits"), *headers.get(":event-type").unwrap());
|
||||
assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap());
|
||||
unitTest("message_with_no_header_payload_traits") {
|
||||
rustTemplate(
|
||||
"""
|
||||
let event = TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder()
|
||||
.some_int(#{IntInput})
|
||||
.some_string(#{StringInput})
|
||||
.build()
|
||||
);
|
||||
let result = $generator::new().marshall(event);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
let message = result.unwrap();
|
||||
let headers = headers_to_map(message.headers());
|
||||
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
|
||||
assert_eq!(&str_header("MessageWithNoHeaderPayloadTraits"), *headers.get(":event-type").unwrap());
|
||||
assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap());
|
||||
|
||||
validate_body(
|
||||
message.payload(),
|
||||
${testCase.validMessageWithNoHeaderPayloadTraits.dq()},
|
||||
MediaType::from(${testCase.requestContentType.dq()})
|
||||
).unwrap();
|
||||
""",
|
||||
)
|
||||
validate_body(
|
||||
message.payload(),
|
||||
${testCase.validMessageWithNoHeaderPayloadTraits.dq()},
|
||||
MediaType::from(${testCase.mediaType.dq()})
|
||||
).unwrap();
|
||||
""",
|
||||
"IntInput" to builderInput("5"),
|
||||
"StringInput" to builderInput("\"hello\""),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ private fun fillInBaseModel(
|
|||
): String = """
|
||||
namespace test
|
||||
|
||||
use smithy.framework#ValidationException
|
||||
use aws.protocols#$protocolName
|
||||
|
||||
union TestUnion {
|
||||
|
@ -69,12 +70,20 @@ private fun fillInBaseModel(
|
|||
MessageWithNoHeaderPayloadTraits: MessageWithNoHeaderPayloadTraits,
|
||||
SomeError: SomeError,
|
||||
}
|
||||
structure TestStreamInputOutput { @httpPayload @required value: TestStream }
|
||||
|
||||
structure TestStreamInputOutput {
|
||||
@required
|
||||
@httpPayload
|
||||
value: TestStream
|
||||
}
|
||||
|
||||
@http(method: "POST", uri: "/test")
|
||||
operation TestStreamOp {
|
||||
input: TestStreamInputOutput,
|
||||
output: TestStreamInputOutput,
|
||||
errors: [SomeError],
|
||||
errors: [SomeError, ValidationException],
|
||||
}
|
||||
|
||||
$extraServiceAnnotations
|
||||
@$protocolName
|
||||
service TestService { version: "123", operations: [TestStreamOp] }
|
||||
|
@ -92,6 +101,7 @@ object EventStreamTestModels {
|
|||
data class TestCase(
|
||||
val protocolShapeId: String,
|
||||
val model: Model,
|
||||
val mediaType: String,
|
||||
val requestContentType: String,
|
||||
val responseContentType: String,
|
||||
val validTestStruct: String,
|
||||
|
@ -111,7 +121,8 @@ object EventStreamTestModels {
|
|||
TestCase(
|
||||
protocolShapeId = "aws.protocols#restJson1",
|
||||
model = restJson1(),
|
||||
requestContentType = "application/json",
|
||||
mediaType = "application/json",
|
||||
requestContentType = "application/vnd.amazon.eventstream",
|
||||
responseContentType = "application/json",
|
||||
validTestStruct = """{"someString":"hello","someInt":5}""",
|
||||
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
|
||||
|
@ -126,6 +137,7 @@ object EventStreamTestModels {
|
|||
TestCase(
|
||||
protocolShapeId = "aws.protocols#awsJson1_1",
|
||||
model = awsJson11(),
|
||||
mediaType = "application/x-amz-json-1.1",
|
||||
requestContentType = "application/x-amz-json-1.1",
|
||||
responseContentType = "application/x-amz-json-1.1",
|
||||
validTestStruct = """{"someString":"hello","someInt":5}""",
|
||||
|
@ -141,7 +153,8 @@ object EventStreamTestModels {
|
|||
TestCase(
|
||||
protocolShapeId = "aws.protocols#restXml",
|
||||
model = restXml(),
|
||||
requestContentType = "application/xml",
|
||||
mediaType = "application/xml",
|
||||
requestContentType = "application/vnd.amazon.eventstream",
|
||||
responseContentType = "application/xml",
|
||||
validTestStruct = """
|
||||
<TestStruct>
|
||||
|
|
|
@ -1,185 +0,0 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package software.amazon.smithy.rust.codegen.core.testutil
|
||||
|
||||
import software.amazon.smithy.model.Model
|
||||
import software.amazon.smithy.model.shapes.OperationShape
|
||||
import software.amazon.smithy.model.shapes.ServiceShape
|
||||
import software.amazon.smithy.model.shapes.Shape
|
||||
import software.amazon.smithy.model.shapes.ShapeId
|
||||
import software.amazon.smithy.model.shapes.StructureShape
|
||||
import software.amazon.smithy.model.shapes.UnionShape
|
||||
import software.amazon.smithy.model.traits.ErrorTrait
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.DirectedWalker
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.transformers.EventStreamNormalizer
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamMarshallTestCases.writeMarshallTestCases
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.writeUnmarshallTestCases
|
||||
import software.amazon.smithy.rust.codegen.core.util.hasTrait
|
||||
import software.amazon.smithy.rust.codegen.core.util.lookup
|
||||
import software.amazon.smithy.rust.codegen.core.util.outputShape
|
||||
|
||||
data class TestEventStreamProject(
|
||||
val model: Model,
|
||||
val serviceShape: ServiceShape,
|
||||
val operationShape: OperationShape,
|
||||
val streamShape: UnionShape,
|
||||
val symbolProvider: RustSymbolProvider,
|
||||
val project: TestWriterDelegator,
|
||||
)
|
||||
|
||||
enum class EventStreamTestVariety {
|
||||
Marshall,
|
||||
Unmarshall,
|
||||
}
|
||||
|
||||
interface EventStreamTestRequirements<C : CodegenContext> {
|
||||
/** Create a codegen context for the tests */
|
||||
fun createCodegenContext(
|
||||
model: Model,
|
||||
serviceShape: ServiceShape,
|
||||
protocolShapeId: ShapeId,
|
||||
codegenTarget: CodegenTarget,
|
||||
): C
|
||||
|
||||
/** Render the event stream marshall/unmarshall code generator */
|
||||
fun renderGenerator(
|
||||
codegenContext: C,
|
||||
project: TestEventStreamProject,
|
||||
protocol: Protocol,
|
||||
): RuntimeType
|
||||
|
||||
/** Render a builder for the given shape */
|
||||
fun renderBuilderForShape(
|
||||
rustCrate: RustCrate,
|
||||
writer: RustWriter,
|
||||
codegenContext: C,
|
||||
shape: StructureShape,
|
||||
)
|
||||
|
||||
/** Render an operation error for the given operation and error shapes */
|
||||
fun renderOperationError(
|
||||
writer: RustWriter,
|
||||
model: Model,
|
||||
symbolProvider: RustSymbolProvider,
|
||||
operationOrEventStream: Shape,
|
||||
)
|
||||
|
||||
/** Render an error struct and builder */
|
||||
fun renderError(rustCrate: RustCrate, writer: RustWriter, codegenContext: C, shape: StructureShape)
|
||||
}
|
||||
|
||||
object EventStreamTestTools {
|
||||
fun <C : CodegenContext> setupTestCase(
|
||||
testCase: EventStreamTestModels.TestCase,
|
||||
requirements: EventStreamTestRequirements<C>,
|
||||
codegenTarget: CodegenTarget,
|
||||
variety: EventStreamTestVariety,
|
||||
transformers: List<(Model) -> Model> = listOf(),
|
||||
): TestWriterDelegator {
|
||||
val model = (listOf(OperationNormalizer::transform, EventStreamNormalizer::transform) + transformers).fold(testCase.model) { model, transformer ->
|
||||
transformer(model)
|
||||
}
|
||||
|
||||
val serviceShape = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape
|
||||
val codegenContext = requirements.createCodegenContext(
|
||||
model,
|
||||
serviceShape,
|
||||
ShapeId.from(testCase.protocolShapeId),
|
||||
codegenTarget,
|
||||
)
|
||||
val test = generateTestProject(requirements, codegenContext, codegenTarget)
|
||||
val protocol = testCase.protocolBuilder(codegenContext)
|
||||
val generator = requirements.renderGenerator(codegenContext, test, protocol)
|
||||
|
||||
test.project.lib {
|
||||
when (variety) {
|
||||
EventStreamTestVariety.Marshall -> writeMarshallTestCases(testCase, generator)
|
||||
EventStreamTestVariety.Unmarshall -> writeUnmarshallTestCases(testCase, codegenTarget, generator)
|
||||
}
|
||||
}
|
||||
|
||||
return test.project
|
||||
}
|
||||
|
||||
private fun <C : CodegenContext> generateTestProject(
|
||||
requirements: EventStreamTestRequirements<C>,
|
||||
codegenContext: C,
|
||||
codegenTarget: CodegenTarget,
|
||||
): TestEventStreamProject {
|
||||
val model = codegenContext.model
|
||||
val symbolProvider = codegenContext.symbolProvider
|
||||
val operationShape = model.expectShape(ShapeId.from("test#TestStreamOp")) as OperationShape
|
||||
val unionShape = model.expectShape(ShapeId.from("test#TestStream")) as UnionShape
|
||||
val walker = DirectedWalker(model)
|
||||
|
||||
val project = TestWorkspace.testProject(symbolProvider)
|
||||
val errors = model.serviceShapes
|
||||
.flatMap { walker.walkShapes(it) }
|
||||
.filterIsInstance<StructureShape>()
|
||||
.filter { shape -> shape.hasTrait<ErrorTrait>() }
|
||||
check(errors.isNotEmpty()) { "must have at least one error modeled" }
|
||||
project.withModule(codegenContext.symbolProvider.moduleForShape(errors[0])) {
|
||||
requirements.renderOperationError(this, model, symbolProvider, operationShape)
|
||||
requirements.renderOperationError(this, model, symbolProvider, unionShape)
|
||||
for (shape in errors) {
|
||||
requirements.renderError(project, this, codegenContext, shape)
|
||||
}
|
||||
}
|
||||
val inputOutput = model.lookup<StructureShape>("test#TestStreamInputOutput")
|
||||
project.withModule(codegenContext.symbolProvider.moduleForShape(inputOutput)) {
|
||||
recursivelyGenerateModels(project, model, symbolProvider, inputOutput, this, codegenTarget)
|
||||
}
|
||||
operationShape.outputShape(model).also { outputShape ->
|
||||
outputShape.renderWithModelBuilder(model, symbolProvider, project)
|
||||
}
|
||||
return TestEventStreamProject(
|
||||
model,
|
||||
codegenContext.serviceShape,
|
||||
operationShape,
|
||||
unionShape,
|
||||
symbolProvider,
|
||||
project,
|
||||
)
|
||||
}
|
||||
|
||||
private fun recursivelyGenerateModels(
|
||||
rustCrate: RustCrate,
|
||||
model: Model,
|
||||
symbolProvider: RustSymbolProvider,
|
||||
shape: Shape,
|
||||
writer: RustWriter,
|
||||
mode: CodegenTarget,
|
||||
) {
|
||||
for (member in shape.members()) {
|
||||
if (member.target.namespace == "smithy.api") {
|
||||
continue
|
||||
}
|
||||
val target = model.expectShape(member.target)
|
||||
when (target) {
|
||||
is StructureShape -> target.renderWithModelBuilder(model, symbolProvider, rustCrate)
|
||||
is UnionShape -> UnionGenerator(
|
||||
model,
|
||||
symbolProvider,
|
||||
writer,
|
||||
target,
|
||||
renderUnknownVariant = mode.renderUnknownVariant(),
|
||||
).render()
|
||||
else -> TODO("EventStreamTestTools doesn't support rendering $target")
|
||||
}
|
||||
recursivelyGenerateModels(rustCrate, model, symbolProvider, target, writer, mode)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -5,17 +5,22 @@
|
|||
|
||||
package software.amazon.smithy.rust.codegen.core.testutil
|
||||
|
||||
import org.intellij.lang.annotations.Language
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rust
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.writable
|
||||
|
||||
internal object EventStreamUnmarshallTestCases {
|
||||
internal fun RustWriter.writeUnmarshallTestCases(
|
||||
object EventStreamUnmarshallTestCases {
|
||||
fun RustWriter.writeUnmarshallTestCases(
|
||||
testCase: EventStreamTestModels.TestCase,
|
||||
codegenTarget: CodegenTarget,
|
||||
generator: RuntimeType,
|
||||
optionalBuilderInputs: Boolean = false,
|
||||
) {
|
||||
val generator = "crate::event_stream_serde::TestStreamUnmarshaller"
|
||||
|
||||
rust(
|
||||
"""
|
||||
use aws_smithy_eventstream::frame::{Header, HeaderValue, Message, UnmarshallMessage, UnmarshalledMessage};
|
||||
|
@ -53,202 +58,199 @@ internal object EventStreamUnmarshallTestCases {
|
|||
""",
|
||||
)
|
||||
|
||||
unitTest(
|
||||
name = "message_with_blob",
|
||||
test = """
|
||||
unitTest("message_with_blob") {
|
||||
rustTemplate(
|
||||
"""
|
||||
let message = msg("event", "MessageWithBlob", "application/octet-stream", b"hello, world!");
|
||||
let result = ${format(generator)}().unmarshall(&message);
|
||||
let result = $generator::new().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
assert_eq!(
|
||||
TestStream::MessageWithBlob(
|
||||
MessageWithBlob::builder().data(Blob::new(&b"hello, world!"[..])).build()
|
||||
MessageWithBlob::builder().data(#{DataInput:W}).build()
|
||||
),
|
||||
expect_event(result.unwrap())
|
||||
);
|
||||
""",
|
||||
)
|
||||
|
||||
if (codegenTarget == CodegenTarget.CLIENT) {
|
||||
unitTest(
|
||||
"unknown_message",
|
||||
"""
|
||||
let message = msg("event", "NewUnmodeledMessageType", "application/octet-stream", b"hello, world!");
|
||||
let result = ${format(generator)}().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
assert_eq!(
|
||||
TestStream::Unknown,
|
||||
expect_event(result.unwrap())
|
||||
);
|
||||
""",
|
||||
"DataInput" to conditionalBuilderInput(
|
||||
"""
|
||||
Blob::new(&b"hello, world!"[..])
|
||||
""",
|
||||
conditional = optionalBuilderInputs,
|
||||
),
|
||||
|
||||
)
|
||||
}
|
||||
|
||||
unitTest(
|
||||
"message_with_string",
|
||||
"""
|
||||
let message = msg("event", "MessageWithString", "text/plain", b"hello, world!");
|
||||
let result = ${format(generator)}().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
assert_eq!(
|
||||
TestStream::MessageWithString(MessageWithString::builder().data("hello, world!").build()),
|
||||
expect_event(result.unwrap())
|
||||
);
|
||||
""",
|
||||
)
|
||||
unitTest("message_with_string") {
|
||||
rustTemplate(
|
||||
"""
|
||||
let message = msg("event", "MessageWithString", "text/plain", b"hello, world!");
|
||||
let result = $generator::new().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
assert_eq!(
|
||||
TestStream::MessageWithString(MessageWithString::builder().data(#{DataInput}).build()),
|
||||
expect_event(result.unwrap())
|
||||
);
|
||||
""",
|
||||
"DataInput" to conditionalBuilderInput("\"hello, world!\"", conditional = optionalBuilderInputs),
|
||||
)
|
||||
}
|
||||
|
||||
unitTest(
|
||||
"message_with_struct",
|
||||
"""
|
||||
let message = msg(
|
||||
"event",
|
||||
"MessageWithStruct",
|
||||
"${testCase.responseContentType}",
|
||||
br#"${testCase.validTestStruct}"#
|
||||
);
|
||||
let result = ${format(generator)}().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
assert_eq!(
|
||||
TestStream::MessageWithStruct(MessageWithStruct::builder().some_struct(
|
||||
unitTest("message_with_struct") {
|
||||
rustTemplate(
|
||||
"""
|
||||
let message = msg(
|
||||
"event",
|
||||
"MessageWithStruct",
|
||||
"${testCase.responseContentType}",
|
||||
br##"${testCase.validTestStruct}"##
|
||||
);
|
||||
let result = $generator::new().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
assert_eq!(
|
||||
TestStream::MessageWithStruct(MessageWithStruct::builder().some_struct(#{StructInput}).build()),
|
||||
expect_event(result.unwrap())
|
||||
);
|
||||
""",
|
||||
"StructInput" to conditionalBuilderInput(
|
||||
"""
|
||||
TestStruct::builder()
|
||||
.some_string("hello")
|
||||
.some_int(5)
|
||||
.some_string(#{StringInput})
|
||||
.some_int(#{IntInput})
|
||||
.build()
|
||||
).build()),
|
||||
expect_event(result.unwrap())
|
||||
);
|
||||
""",
|
||||
)
|
||||
|
||||
unitTest(
|
||||
"message_with_union",
|
||||
"""
|
||||
let message = msg(
|
||||
"event",
|
||||
"MessageWithUnion",
|
||||
"${testCase.responseContentType}",
|
||||
br#"${testCase.validTestUnion}"#
|
||||
);
|
||||
let result = ${format(generator)}().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
assert_eq!(
|
||||
TestStream::MessageWithUnion(MessageWithUnion::builder().some_union(
|
||||
TestUnion::Foo("hello".into())
|
||||
).build()),
|
||||
expect_event(result.unwrap())
|
||||
);
|
||||
""",
|
||||
)
|
||||
|
||||
unitTest(
|
||||
"message_with_headers",
|
||||
"""
|
||||
let message = msg("event", "MessageWithHeaders", "application/octet-stream", b"")
|
||||
.add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into())))
|
||||
.add_header(Header::new("boolean", HeaderValue::Bool(true)))
|
||||
.add_header(Header::new("byte", HeaderValue::Byte(55i8)))
|
||||
.add_header(Header::new("int", HeaderValue::Int32(100_000i32)))
|
||||
.add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64)))
|
||||
.add_header(Header::new("short", HeaderValue::Int16(16_000i16)))
|
||||
.add_header(Header::new("string", HeaderValue::String("test".into())))
|
||||
.add_header(Header::new("timestamp", HeaderValue::Timestamp(DateTime::from_secs(5))));
|
||||
let result = ${format(generator)}().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
assert_eq!(
|
||||
TestStream::MessageWithHeaders(MessageWithHeaders::builder()
|
||||
.blob(Blob::new(&b"test"[..]))
|
||||
.boolean(true)
|
||||
.byte(55i8)
|
||||
.int(100_000i32)
|
||||
.long(9_000_000_000i64)
|
||||
.short(16_000i16)
|
||||
.string("test")
|
||||
.timestamp(DateTime::from_secs(5))
|
||||
.build()
|
||||
""",
|
||||
conditional = optionalBuilderInputs,
|
||||
"StringInput" to conditionalBuilderInput("\"hello\"", conditional = optionalBuilderInputs),
|
||||
"IntInput" to conditionalBuilderInput("5", conditional = optionalBuilderInputs),
|
||||
),
|
||||
expect_event(result.unwrap())
|
||||
);
|
||||
""",
|
||||
)
|
||||
|
||||
unitTest(
|
||||
"message_with_header_and_payload",
|
||||
"""
|
||||
let message = msg("event", "MessageWithHeaderAndPayload", "application/octet-stream", b"payload")
|
||||
.add_header(Header::new("header", HeaderValue::String("header".into())));
|
||||
let result = ${format(generator)}().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
assert_eq!(
|
||||
TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder()
|
||||
.header("header")
|
||||
.payload(Blob::new(&b"payload"[..]))
|
||||
.build()
|
||||
),
|
||||
expect_event(result.unwrap())
|
||||
);
|
||||
""",
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
unitTest(
|
||||
"message_with_no_header_payload_traits",
|
||||
"""
|
||||
let message = msg(
|
||||
"event",
|
||||
"MessageWithNoHeaderPayloadTraits",
|
||||
"${testCase.responseContentType}",
|
||||
br#"${testCase.validMessageWithNoHeaderPayloadTraits}"#
|
||||
);
|
||||
let result = ${format(generator)}().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
assert_eq!(
|
||||
TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder()
|
||||
.some_int(5)
|
||||
.some_string("hello")
|
||||
.build()
|
||||
),
|
||||
expect_event(result.unwrap())
|
||||
);
|
||||
""",
|
||||
)
|
||||
unitTest("message_with_union") {
|
||||
rustTemplate(
|
||||
"""
|
||||
let message = msg(
|
||||
"event",
|
||||
"MessageWithUnion",
|
||||
"${testCase.responseContentType}",
|
||||
br##"${testCase.validTestUnion}"##
|
||||
);
|
||||
let result = $generator::new().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
assert_eq!(
|
||||
TestStream::MessageWithUnion(MessageWithUnion::builder().some_union(#{UnionInput}).build()),
|
||||
expect_event(result.unwrap())
|
||||
);
|
||||
""",
|
||||
"UnionInput" to conditionalBuilderInput("TestUnion::Foo(\"hello\".into())", conditional = optionalBuilderInputs),
|
||||
)
|
||||
}
|
||||
|
||||
unitTest(
|
||||
"some_error",
|
||||
"""
|
||||
let message = msg(
|
||||
"exception",
|
||||
"SomeError",
|
||||
"${testCase.responseContentType}",
|
||||
br#"${testCase.validSomeError}"#
|
||||
);
|
||||
let result = ${format(generator)}().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
match expect_error(result.unwrap()) {
|
||||
TestStreamError::SomeError(err) => assert_eq!(Some("some error"), err.message()),
|
||||
kind => panic!("expected SomeError, but got {:?}", kind),
|
||||
}
|
||||
""",
|
||||
)
|
||||
unitTest("message_with_headers") {
|
||||
rustTemplate(
|
||||
"""
|
||||
let message = msg("event", "MessageWithHeaders", "application/octet-stream", b"")
|
||||
.add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into())))
|
||||
.add_header(Header::new("boolean", HeaderValue::Bool(true)))
|
||||
.add_header(Header::new("byte", HeaderValue::Byte(55i8)))
|
||||
.add_header(Header::new("int", HeaderValue::Int32(100_000i32)))
|
||||
.add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64)))
|
||||
.add_header(Header::new("short", HeaderValue::Int16(16_000i16)))
|
||||
.add_header(Header::new("string", HeaderValue::String("test".into())))
|
||||
.add_header(Header::new("timestamp", HeaderValue::Timestamp(DateTime::from_secs(5))));
|
||||
let result = $generator::new().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
assert_eq!(
|
||||
TestStream::MessageWithHeaders(MessageWithHeaders::builder()
|
||||
.blob(#{BlobInput})
|
||||
.boolean(#{BoolInput})
|
||||
.byte(#{ByteInput})
|
||||
.int(#{IntInput})
|
||||
.long(#{LongInput})
|
||||
.short(#{ShortInput})
|
||||
.string(#{StringInput})
|
||||
.timestamp(#{TimestampInput})
|
||||
.build()
|
||||
),
|
||||
expect_event(result.unwrap())
|
||||
);
|
||||
""",
|
||||
"BlobInput" to conditionalBuilderInput("Blob::new(&b\"test\"[..])", conditional = optionalBuilderInputs),
|
||||
"BoolInput" to conditionalBuilderInput("true", conditional = optionalBuilderInputs),
|
||||
"ByteInput" to conditionalBuilderInput("55i8", conditional = optionalBuilderInputs),
|
||||
"IntInput" to conditionalBuilderInput("100_000i32", conditional = optionalBuilderInputs),
|
||||
"LongInput" to conditionalBuilderInput("9_000_000_000i64", conditional = optionalBuilderInputs),
|
||||
"ShortInput" to conditionalBuilderInput("16_000i16", conditional = optionalBuilderInputs),
|
||||
"StringInput" to conditionalBuilderInput("\"test\"", conditional = optionalBuilderInputs),
|
||||
"TimestampInput" to conditionalBuilderInput("DateTime::from_secs(5)", conditional = optionalBuilderInputs),
|
||||
)
|
||||
}
|
||||
|
||||
if (codegenTarget == CodegenTarget.CLIENT) {
|
||||
unitTest(
|
||||
"error_metadata",
|
||||
unitTest("message_with_header_and_payload") {
|
||||
rustTemplate(
|
||||
"""
|
||||
let message = msg("event", "MessageWithHeaderAndPayload", "application/octet-stream", b"payload")
|
||||
.add_header(Header::new("header", HeaderValue::String("header".into())));
|
||||
let result = $generator::new().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
assert_eq!(
|
||||
TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder()
|
||||
.header(#{HeaderInput})
|
||||
.payload(#{PayloadInput})
|
||||
.build()
|
||||
),
|
||||
expect_event(result.unwrap())
|
||||
);
|
||||
""",
|
||||
"HeaderInput" to conditionalBuilderInput("\"header\"", conditional = optionalBuilderInputs),
|
||||
"PayloadInput" to conditionalBuilderInput("Blob::new(&b\"payload\"[..])", conditional = optionalBuilderInputs),
|
||||
)
|
||||
}
|
||||
|
||||
unitTest("message_with_no_header_payload_traits") {
|
||||
rustTemplate(
|
||||
"""
|
||||
let message = msg(
|
||||
"event",
|
||||
"MessageWithNoHeaderPayloadTraits",
|
||||
"${testCase.responseContentType}",
|
||||
br##"${testCase.validMessageWithNoHeaderPayloadTraits}"##
|
||||
);
|
||||
let result = $generator::new().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
assert_eq!(
|
||||
TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder()
|
||||
.some_int(#{IntInput})
|
||||
.some_string(#{StringInput})
|
||||
.build()
|
||||
),
|
||||
expect_event(result.unwrap())
|
||||
);
|
||||
""",
|
||||
"IntInput" to conditionalBuilderInput("5", conditional = optionalBuilderInputs),
|
||||
"StringInput" to conditionalBuilderInput("\"hello\"", conditional = optionalBuilderInputs),
|
||||
)
|
||||
}
|
||||
|
||||
unitTest("some_error") {
|
||||
rustTemplate(
|
||||
"""
|
||||
let message = msg(
|
||||
"exception",
|
||||
"UnmodeledError",
|
||||
"SomeError",
|
||||
"${testCase.responseContentType}",
|
||||
br#"${testCase.validUnmodeledError}"#
|
||||
br##"${testCase.validSomeError}"##
|
||||
);
|
||||
let result = ${format(generator)}().unmarshall(&message);
|
||||
let result = $generator::new().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
match expect_error(result.unwrap()) {
|
||||
TestStreamError::Unhandled(err) => {
|
||||
let message = format!("{}", aws_smithy_types::error::display::DisplayErrorContext(&err));
|
||||
let expected = "message: \"unmodeled error\"";
|
||||
assert!(message.contains(expected), "Expected '{message}' to contain '{expected}'");
|
||||
}
|
||||
kind => panic!("expected error metadata, but got {:?}", kind),
|
||||
TestStreamError::SomeError(err) => assert_eq!(Some("some error"), err.message()),
|
||||
#{AllowUnreachablePatterns:W}
|
||||
kind => panic!("expected SomeError, but got {:?}", kind),
|
||||
}
|
||||
""",
|
||||
"AllowUnreachablePatterns" to writable { Attribute.AllowUnreachablePatterns.render(this) },
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -261,10 +263,21 @@ internal object EventStreamUnmarshallTestCases {
|
|||
"wrong-content-type",
|
||||
br#"${testCase.validTestStruct}"#
|
||||
);
|
||||
let result = ${format(generator)}().unmarshall(&message);
|
||||
let result = $generator::new().unmarshall(&message);
|
||||
assert!(result.is_err(), "expected error, got: {:?}", result);
|
||||
assert!(format!("{}", result.err().unwrap()).contains("expected :content-type to be"));
|
||||
""",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
internal fun conditionalBuilderInput(
|
||||
@Language("Rust", prefix = "macro_rules! foo { () => {{\n", suffix = "\n}}}") contents: String,
|
||||
conditional: Boolean,
|
||||
vararg ctx: Pair<String, Any>,
|
||||
): Writable =
|
||||
writable {
|
||||
conditionalBlock("Some(", ".into())", conditional = conditional) {
|
||||
rustTemplate(contents, *ctx)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,121 +0,0 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package software.amazon.smithy.rust.codegen.server.smithy.protocols.eventstream
|
||||
|
||||
import org.junit.jupiter.api.extension.ExtensionContext
|
||||
import org.junit.jupiter.params.provider.Arguments
|
||||
import org.junit.jupiter.params.provider.ArgumentsProvider
|
||||
import software.amazon.smithy.model.Model
|
||||
import software.amazon.smithy.model.shapes.ServiceShape
|
||||
import software.amazon.smithy.model.shapes.Shape
|
||||
import software.amazon.smithy.model.shapes.ShapeId
|
||||
import software.amazon.smithy.model.shapes.StructureShape
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.implBlock
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestRequirements
|
||||
import software.amazon.smithy.rust.codegen.core.util.getTrait
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenConfig
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionConversionGenerator
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGenerator
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGeneratorWithoutPublicConstrainedTypes
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerOperationErrorGenerator
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestRustSettings
|
||||
import java.util.stream.Stream
|
||||
|
||||
data class TestCase(
|
||||
val eventStreamTestCase: EventStreamTestModels.TestCase,
|
||||
val publicConstrainedTypes: Boolean,
|
||||
) {
|
||||
override fun toString(): String = "$eventStreamTestCase, publicConstrainedTypes = $publicConstrainedTypes"
|
||||
}
|
||||
|
||||
class TestCasesProvider : ArgumentsProvider {
|
||||
override fun provideArguments(context: ExtensionContext?): Stream<out Arguments> =
|
||||
EventStreamTestModels.TEST_CASES
|
||||
.flatMap { testCase ->
|
||||
listOf(
|
||||
TestCase(testCase, publicConstrainedTypes = false),
|
||||
TestCase(testCase, publicConstrainedTypes = true),
|
||||
)
|
||||
}.map { Arguments.of(it) }.stream()
|
||||
}
|
||||
|
||||
abstract class ServerEventStreamBaseRequirements : EventStreamTestRequirements<ServerCodegenContext> {
|
||||
abstract val publicConstrainedTypes: Boolean
|
||||
|
||||
override fun createCodegenContext(
|
||||
model: Model,
|
||||
serviceShape: ServiceShape,
|
||||
protocolShapeId: ShapeId,
|
||||
codegenTarget: CodegenTarget,
|
||||
): ServerCodegenContext = serverTestCodegenContext(
|
||||
model,
|
||||
serviceShape,
|
||||
serverTestRustSettings(
|
||||
codegenConfig = ServerCodegenConfig(publicConstrainedTypes = publicConstrainedTypes),
|
||||
),
|
||||
protocolShapeId,
|
||||
)
|
||||
|
||||
override fun renderBuilderForShape(
|
||||
rustCrate: RustCrate,
|
||||
writer: RustWriter,
|
||||
codegenContext: ServerCodegenContext,
|
||||
shape: StructureShape,
|
||||
) {
|
||||
val validationExceptionConversionGenerator = SmithyValidationExceptionConversionGenerator(codegenContext)
|
||||
if (codegenContext.settings.codegenConfig.publicConstrainedTypes) {
|
||||
ServerBuilderGenerator(codegenContext, shape, validationExceptionConversionGenerator).apply {
|
||||
render(rustCrate, writer)
|
||||
writer.implBlock(codegenContext.symbolProvider.toSymbol(shape)) {
|
||||
renderConvenienceMethod(writer)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ServerBuilderGeneratorWithoutPublicConstrainedTypes(codegenContext, shape, validationExceptionConversionGenerator).apply {
|
||||
render(rustCrate, writer)
|
||||
writer.implBlock(codegenContext.symbolProvider.toSymbol(shape)) {
|
||||
renderConvenienceMethod(writer)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun renderOperationError(
|
||||
writer: RustWriter,
|
||||
model: Model,
|
||||
symbolProvider: RustSymbolProvider,
|
||||
operationOrEventStream: Shape,
|
||||
) {
|
||||
ServerOperationErrorGenerator(model, symbolProvider, operationOrEventStream).render(writer)
|
||||
}
|
||||
|
||||
override fun renderError(
|
||||
rustCrate: RustCrate,
|
||||
writer: RustWriter,
|
||||
codegenContext: ServerCodegenContext,
|
||||
shape: StructureShape,
|
||||
) {
|
||||
StructureGenerator(codegenContext.model, codegenContext.symbolProvider, writer, shape, listOf()).render()
|
||||
ErrorImplGenerator(
|
||||
codegenContext.model,
|
||||
codegenContext.symbolProvider,
|
||||
writer,
|
||||
shape,
|
||||
shape.getTrait()!!,
|
||||
listOf(),
|
||||
).render(CodegenTarget.SERVER)
|
||||
renderBuilderForShape(rustCrate, writer, codegenContext, shape)
|
||||
}
|
||||
}
|
|
@ -5,49 +5,43 @@
|
|||
|
||||
package software.amazon.smithy.rust.codegen.server.smithy.protocols.eventstream
|
||||
|
||||
import org.junit.jupiter.api.extension.ExtensionContext
|
||||
import org.junit.jupiter.params.ParameterizedTest
|
||||
import org.junit.jupiter.params.provider.Arguments
|
||||
import org.junit.jupiter.params.provider.ArgumentsProvider
|
||||
import org.junit.jupiter.params.provider.ArgumentsSource
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.EventStreamMarshallerGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.renderInlineMemoryModules
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamMarshallTestCases.writeMarshallTestCases
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.testModule
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest
|
||||
import java.util.stream.Stream
|
||||
|
||||
class ServerEventStreamMarshallerGeneratorTest {
|
||||
@ParameterizedTest
|
||||
@ArgumentsSource(TestCasesProvider::class)
|
||||
fun test(testCase: TestCase) {
|
||||
val testProject = EventStreamTestTools.setupTestCase(
|
||||
testCase.eventStreamTestCase,
|
||||
object : ServerEventStreamBaseRequirements() {
|
||||
override val publicConstrainedTypes: Boolean get() = testCase.publicConstrainedTypes
|
||||
|
||||
override fun renderGenerator(
|
||||
codegenContext: ServerCodegenContext,
|
||||
project: TestEventStreamProject,
|
||||
protocol: Protocol,
|
||||
): RuntimeType {
|
||||
return EventStreamMarshallerGenerator(
|
||||
project.model,
|
||||
CodegenTarget.SERVER,
|
||||
TestRuntimeConfig,
|
||||
project.symbolProvider,
|
||||
project.streamShape,
|
||||
protocol.structuredDataSerializer(project.operationShape),
|
||||
testCase.eventStreamTestCase.requestContentType,
|
||||
).render()
|
||||
}
|
||||
},
|
||||
CodegenTarget.SERVER,
|
||||
EventStreamTestVariety.Marshall,
|
||||
)
|
||||
testProject.renderInlineMemoryModules()
|
||||
testProject.compileAndTest()
|
||||
serverIntegrationTest(testCase.eventStreamTestCase.model) { _, rustCrate ->
|
||||
rustCrate.testModule {
|
||||
writeMarshallTestCases(testCase.eventStreamTestCase, optionalBuilderInputs = true)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
data class TestCase(
|
||||
val eventStreamTestCase: EventStreamTestModels.TestCase,
|
||||
val publicConstrainedTypes: Boolean,
|
||||
) {
|
||||
override fun toString(): String = "$eventStreamTestCase, publicConstrainedTypes = $publicConstrainedTypes"
|
||||
}
|
||||
|
||||
class TestCasesProvider : ArgumentsProvider {
|
||||
override fun provideArguments(context: ExtensionContext?): Stream<out Arguments> =
|
||||
EventStreamTestModels.TEST_CASES
|
||||
.flatMap { testCase ->
|
||||
listOf(
|
||||
TestCase(testCase, publicConstrainedTypes = false),
|
||||
TestCase(testCase, publicConstrainedTypes = true),
|
||||
)
|
||||
}.map { Arguments.of(it) }.stream()
|
||||
}
|
||||
|
|
|
@ -7,21 +7,10 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols.eventstream
|
|||
|
||||
import org.junit.jupiter.params.ParameterizedTest
|
||||
import org.junit.jupiter.params.provider.ArgumentsSource
|
||||
import software.amazon.smithy.model.shapes.StructureShape
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.implBlock
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.EventStreamUnmarshallerGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.transformers.ConstrainedMemberTransform
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.writeUnmarshallTestCases
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.testModule
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest
|
||||
|
||||
class ServerEventStreamUnmarshallerGeneratorTest {
|
||||
@ParameterizedTest
|
||||
|
@ -33,45 +22,16 @@ class ServerEventStreamUnmarshallerGeneratorTest {
|
|||
return
|
||||
}
|
||||
|
||||
val testProject = EventStreamTestTools.setupTestCase(
|
||||
testCase.eventStreamTestCase,
|
||||
object : ServerEventStreamBaseRequirements() {
|
||||
override val publicConstrainedTypes: Boolean get() = testCase.publicConstrainedTypes
|
||||
|
||||
override fun renderGenerator(
|
||||
codegenContext: ServerCodegenContext,
|
||||
project: TestEventStreamProject,
|
||||
protocol: Protocol,
|
||||
): RuntimeType {
|
||||
return EventStreamUnmarshallerGenerator(
|
||||
protocol,
|
||||
codegenContext,
|
||||
project.operationShape,
|
||||
project.streamShape,
|
||||
).render()
|
||||
}
|
||||
|
||||
// TODO(https://github.com/awslabs/smithy-rs/issues/1442): Delete this function override to use the correct builder from the parent class
|
||||
override fun renderBuilderForShape(
|
||||
rustCrate: RustCrate,
|
||||
writer: RustWriter,
|
||||
codegenContext: ServerCodegenContext,
|
||||
shape: StructureShape,
|
||||
) {
|
||||
rustCrate.withModule(codegenContext.symbolProvider.moduleForBuilder(shape)) {
|
||||
BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape, emptyList()).render(this)
|
||||
}
|
||||
rustCrate.moduleFor(shape) {
|
||||
writer.implBlock(codegenContext.symbolProvider.toSymbol(shape)) {
|
||||
BuilderGenerator.renderConvenienceMethod(this, codegenContext.symbolProvider, shape)
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
CodegenTarget.SERVER,
|
||||
EventStreamTestVariety.Unmarshall,
|
||||
transformers = listOf(ConstrainedMemberTransform::transform),
|
||||
)
|
||||
testProject.compileAndTest()
|
||||
serverIntegrationTest(
|
||||
testCase.eventStreamTestCase.model,
|
||||
IntegrationTestParams(service = "test#TestService", addModuleToEventStreamAllowList = true),
|
||||
) { _, rustCrate ->
|
||||
rustCrate.testModule {
|
||||
writeUnmarshallTestCases(
|
||||
testCase.eventStreamTestCase,
|
||||
optionalBuilderInputs = true,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ pub fn body_is_error(body: &[u8]) -> Result<bool, XmlDecodeError> {
|
|||
Ok(scoped.start_el().matches("ErrorResponse"))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn parse_error_metadata(body: &[u8]) -> Result<ErrorMetadataBuilder, XmlDecodeError> {
|
||||
let mut doc = Document::try_from(body)?;
|
||||
let mut root = doc.root_element()?;
|
||||
|
|
Loading…
Reference in New Issue