Break `codegen-server`'s dependency on `codegen-client` (#2105)

* Move the allow lints customization into `codegen-core`
* Move the crate version customization into `codegen-core`
* Move "pub use" extra into `codegen-core`
* Move `EventStreamSymbolProvider` into `codegen-core`
* Move the streaming shape providers into `codegen-core`
* Refactor event stream marshall/unmarshall tests
* Break `codegen-server` dependency on `codegen-client`
* Split up `EventStreamTestTools`
* Move codegen context creation in event stream tests
* Restructure tests so that #1442 is easier to resolve in the future
* Add client/server prefixes to test classes
* Improve TODO comments in server event stream tests
* Use correct builders for `ServerEventStreamMarshallerGeneratorTest`
* Remove test cases for protocols that don't support event streams
This commit is contained in:
John DiSanti 2022-12-20 13:35:32 -08:00 committed by GitHub
parent 3a3d1210c5
commit 0b4c5ab3c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 1245 additions and 1011 deletions

View File

@ -21,6 +21,9 @@ import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.Non
import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWordSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.BaseSymbolMetadataProvider
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.EventStreamSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.StreamingShapeMetadataProvider
import software.amazon.smithy.rust.codegen.core.smithy.StreamingShapeSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitor
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig
import java.util.logging.Level

View File

@ -7,18 +7,18 @@ package software.amazon.smithy.rust.codegen.client.smithy.customize
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.customizations.AllowLintsGenerator
import software.amazon.smithy.rust.codegen.client.smithy.customizations.CrateVersionGenerator
import software.amazon.smithy.rust.codegen.client.smithy.customizations.EndpointPrefixGenerator
import software.amazon.smithy.rust.codegen.client.smithy.customizations.HttpChecksumRequiredGenerator
import software.amazon.smithy.rust.codegen.client.smithy.customizations.HttpVersionListCustomization
import software.amazon.smithy.rust.codegen.client.smithy.customizations.IdempotencyTokenGenerator
import software.amazon.smithy.rust.codegen.client.smithy.customizations.ResiliencyConfigCustomization
import software.amazon.smithy.rust.codegen.client.smithy.customizations.ResiliencyReExportCustomization
import software.amazon.smithy.rust.codegen.client.smithy.customizations.pubUseSmithyTypes
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.core.rustlang.Feature
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.customizations.AllowLintsCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customizations.CrateVersionCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customizations.pubUseSmithyTypes
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization
@ -52,8 +52,7 @@ class RequiredCustomizations : ClientCodegenDecorator {
codegenContext: ClientCodegenContext,
baseCustomizations: List<LibRsCustomization>,
): List<LibRsCustomization> =
baseCustomizations + CrateVersionGenerator() +
AllowLintsGenerator()
baseCustomizations + CrateVersionCustomization() + AllowLintsCustomization()
override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) {
// Add rt-tokio feature for `ByteStream::from_path`

View File

@ -12,6 +12,7 @@ import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.EventStreamSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitor
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer

View File

@ -0,0 +1,59 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.client.smithy.protocols.eventstream
import org.junit.jupiter.api.extension.ExtensionContext
import org.junit.jupiter.params.provider.Arguments
import org.junit.jupiter.params.provider.ArgumentsProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.customize.CombinedClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.testutil.clientTestRustSettings
import software.amazon.smithy.rust.codegen.client.testutil.testSymbolProvider
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestRequirements
import java.util.stream.Stream
class TestCasesProvider : ArgumentsProvider {
override fun provideArguments(context: ExtensionContext?): Stream<out Arguments> =
EventStreamTestModels.TEST_CASES.map { Arguments.of(it) }.stream()
}
abstract class ClientEventStreamBaseRequirements : EventStreamTestRequirements<ClientCodegenContext> {
override fun createCodegenContext(
model: Model,
serviceShape: ServiceShape,
protocolShapeId: ShapeId,
codegenTarget: CodegenTarget,
): ClientCodegenContext = ClientCodegenContext(
model,
testSymbolProvider(model),
serviceShape,
protocolShapeId,
clientTestRustSettings(),
CombinedClientCodegenDecorator(emptyList()),
)
override fun renderBuilderForShape(
writer: RustWriter,
codegenContext: ClientCodegenContext,
shape: StructureShape,
) {
BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape).apply {
render(writer)
writer.implBlock(shape, codegenContext.symbolProvider) {
renderConvenienceMethod(writer)
}
}
}
}

View File

@ -0,0 +1,46 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.client.smithy.protocols.eventstream
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ArgumentsSource
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.EventStreamMarshallerGenerator
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety
import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject
import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig
class ClientEventStreamMarshallerGeneratorTest {
@ParameterizedTest
@ArgumentsSource(TestCasesProvider::class)
fun test(testCase: EventStreamTestModels.TestCase) {
EventStreamTestTools.runTestCase(
testCase,
object : ClientEventStreamBaseRequirements() {
override fun renderGenerator(
codegenContext: ClientCodegenContext,
project: TestEventStreamProject,
protocol: Protocol,
): RuntimeType = EventStreamMarshallerGenerator(
project.model,
CodegenTarget.CLIENT,
TestRuntimeConfig,
project.symbolProvider,
project.streamShape,
protocol.structuredDataSerializer(project.operationShape),
testCase.requestContentType,
).render()
},
CodegenTarget.CLIENT,
EventStreamTestVariety.Marshall,
)
}
}

View File

@ -0,0 +1,49 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.client.smithy.protocols.eventstream
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ArgumentsSource
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.EventStreamUnmarshallerGenerator
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety
import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject
class ClientEventStreamUnmarshallerGeneratorTest {
@ParameterizedTest
@ArgumentsSource(TestCasesProvider::class)
fun test(testCase: EventStreamTestModels.TestCase) {
EventStreamTestTools.runTestCase(
testCase,
object : ClientEventStreamBaseRequirements() {
override fun renderGenerator(
codegenContext: ClientCodegenContext,
project: TestEventStreamProject,
protocol: Protocol,
): RuntimeType {
fun builderSymbol(shape: StructureShape): Symbol = shape.builderSymbol(codegenContext.symbolProvider)
return EventStreamUnmarshallerGenerator(
protocol,
codegenContext,
project.operationShape,
project.streamShape,
::builderSymbol,
).render()
}
},
CodegenTarget.CLIENT,
EventStreamTestVariety.Unmarshall,
)
}
}

View File

@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.client.smithy
package software.amazon.smithy.rust.codegen.core.smithy
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.Model
@ -14,13 +14,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.render
import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.eventStreamErrorSymbol
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait
import software.amazon.smithy.rust.codegen.core.smithy.transformers.eventStreamErrors

View File

@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.client.smithy
package software.amazon.smithy.rust.codegen.core.smithy
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.Model
@ -14,13 +14,6 @@ import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata
import software.amazon.smithy.rust.codegen.core.smithy.Default
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.SymbolMetadataProvider
import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata
import software.amazon.smithy.rust.codegen.core.smithy.setDefault
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait
import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember

View File

@ -3,19 +3,19 @@
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.client.smithy.customizations
package software.amazon.smithy.rust.codegen.core.smithy.customizations
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection
val AllowedRustcLints = listOf(
private val allowedRustcLints = listOf(
// Deprecated items should be safe to compile, so don't block the compilation.
"deprecated",
)
val AllowedClippyLints = listOf(
private val allowedClippyLints = listOf(
// Sometimes operations are named the same as our module e.g. output leading to `output::output`.
"module_inception",
@ -54,16 +54,16 @@ val AllowedClippyLints = listOf(
// "result_large_err",
)
val AllowedRustdocLints = listOf(
private val allowedRustdocLints = listOf(
// Rust >=1.53.0 requires links to be wrapped in `<link>`. This is extremely hard to enforce for
// docs that come from the modeled documentation, so we need to disable this lint
"bare_urls",
)
class AllowLintsGenerator(
private val rustcLints: List<String> = AllowedRustcLints,
private val clippyLints: List<String> = AllowedClippyLints,
private val rustdocLints: List<String> = AllowedRustdocLints,
class AllowLintsCustomization(
private val rustcLints: List<String> = allowedRustcLints,
private val clippyLints: List<String> = allowedClippyLints,
private val rustdocLints: List<String> = allowedRustdocLints,
) : LibRsCustomization() {
override fun section(section: LibRsSection) = when (section) {
is LibRsSection.Attributes -> writable {

View File

@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.client.smithy.customizations
package software.amazon.smithy.rust.codegen.core.smithy.customizations
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.writable
@ -13,7 +13,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection
/**
* Add `PGK_VERSION` const in lib.rs to enable knowing the version of the current module
*/
class CrateVersionGenerator : LibRsCustomization() {
class CrateVersionCustomization : LibRsCustomization() {
override fun section(section: LibRsSection) =
writable {
if (section is LibRsSection.Body) {

View File

@ -3,9 +3,10 @@
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.client.smithy.customizations
package software.amazon.smithy.rust.codegen.core.smithy.customizations
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.rust
@ -30,23 +31,19 @@ private fun hasStreamingOperations(model: Model): Boolean {
}
}
/** Returns true if the model has any blob shapes or members */
private fun hasBlobs(model: Model): Boolean {
return model.structureShapes.any { structure ->
structure.members().any { member -> model.expectShape(member.target).isBlobShape }
// TODO(https://github.com/awslabs/smithy-rs/issues/2111): Fix this logic to consider collection/map shapes
private fun structUnionMembersMatchPredicate(model: Model, predicate: (Shape) -> Boolean): Boolean =
model.structureShapes.any { structure ->
structure.members().any { member -> predicate(model.expectShape(member.target)) }
} || model.unionShapes.any { union ->
union.members().any { member -> model.expectShape(member.target).isBlobShape }
union.members().any { member -> predicate(model.expectShape(member.target)) }
}
}
/** Returns true if the model has any timestamp shapes or members */
private fun hasDateTimes(model: Model): Boolean {
return model.structureShapes.any { structure ->
structure.members().any { member -> model.expectShape(member.target).isTimestampShape }
} || model.unionShapes.any { union ->
union.members().any { member -> model.expectShape(member.target).isTimestampShape }
}
}
/** Returns true if the model uses any blob shapes */
private fun hasBlobs(model: Model): Boolean = structUnionMembersMatchPredicate(model, Shape::isBlobShape)
/** Returns true if the model uses any timestamp shapes */
private fun hasDateTimes(model: Model): Boolean = structUnionMembersMatchPredicate(model, Shape::isTimestampShape)
/** Returns a list of types that should be re-exported for the given model */
internal fun pubUseTypes(runtimeConfig: RuntimeConfig, model: Model): List<RuntimeType> {

View File

@ -0,0 +1,208 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.core.testutil
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.util.dq
internal object EventStreamMarshallTestCases {
internal fun RustWriter.writeMarshallTestCases(
testCase: EventStreamTestModels.TestCase,
generator: RuntimeType,
) {
val protocolTestHelpers = CargoDependency.smithyProtocolTestHelpers(TestRuntimeConfig)
.copy(scope = DependencyScope.Compile)
rustTemplate(
"""
use aws_smithy_eventstream::frame::{Message, Header, HeaderValue, MarshallMessage};
use std::collections::HashMap;
use aws_smithy_types::{Blob, DateTime};
use crate::error::*;
use crate::model::*;
use #{validate_body};
use #{MediaType};
fn headers_to_map<'a>(headers: &'a [Header]) -> HashMap<String, &'a HeaderValue> {
let mut map = HashMap::new();
for header in headers {
map.insert(header.name().as_str().to_string(), header.value());
}
map
}
fn str_header(value: &'static str) -> HeaderValue {
HeaderValue::String(value.into())
}
""",
"validate_body" to protocolTestHelpers.toType().resolve("validate_body"),
"MediaType" to protocolTestHelpers.toType().resolve("MediaType"),
)
unitTest(
"message_with_blob",
"""
let event = TestStream::MessageWithBlob(
MessageWithBlob::builder().data(Blob::new(&b"hello, world!"[..])).build()
);
let result = ${format(generator)}().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let message = result.unwrap();
let headers = headers_to_map(message.headers());
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
assert_eq!(&str_header("MessageWithBlob"), *headers.get(":event-type").unwrap());
assert_eq!(&str_header("application/octet-stream"), *headers.get(":content-type").unwrap());
assert_eq!(&b"hello, world!"[..], message.payload());
""",
)
unitTest(
"message_with_string",
"""
let event = TestStream::MessageWithString(
MessageWithString::builder().data("hello, world!").build()
);
let result = ${format(generator)}().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let message = result.unwrap();
let headers = headers_to_map(message.headers());
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
assert_eq!(&str_header("MessageWithString"), *headers.get(":event-type").unwrap());
assert_eq!(&str_header("text/plain"), *headers.get(":content-type").unwrap());
assert_eq!(&b"hello, world!"[..], message.payload());
""",
)
unitTest(
"message_with_struct",
"""
let event = TestStream::MessageWithStruct(
MessageWithStruct::builder().some_struct(
TestStruct::builder()
.some_string("hello")
.some_int(5)
.build()
).build()
);
let result = ${format(generator)}().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let message = result.unwrap();
let headers = headers_to_map(message.headers());
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
assert_eq!(&str_header("MessageWithStruct"), *headers.get(":event-type").unwrap());
assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap());
validate_body(
message.payload(),
${testCase.validTestStruct.dq()},
MediaType::from(${testCase.requestContentType.dq()})
).unwrap();
""",
)
unitTest(
"message_with_union",
"""
let event = TestStream::MessageWithUnion(MessageWithUnion::builder().some_union(
TestUnion::Foo("hello".into())
).build());
let result = ${format(generator)}().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let message = result.unwrap();
let headers = headers_to_map(message.headers());
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
assert_eq!(&str_header("MessageWithUnion"), *headers.get(":event-type").unwrap());
assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap());
validate_body(
message.payload(),
${testCase.validTestUnion.dq()},
MediaType::from(${testCase.requestContentType.dq()})
).unwrap();
""",
)
unitTest(
"message_with_headers",
"""
let event = TestStream::MessageWithHeaders(MessageWithHeaders::builder()
.blob(Blob::new(&b"test"[..]))
.boolean(true)
.byte(55i8)
.int(100_000i32)
.long(9_000_000_000i64)
.short(16_000i16)
.string("test")
.timestamp(DateTime::from_secs(5))
.build()
);
let result = ${format(generator)}().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let actual_message = result.unwrap();
let expected_message = Message::new(&b""[..])
.add_header(Header::new(":message-type", HeaderValue::String("event".into())))
.add_header(Header::new(":event-type", HeaderValue::String("MessageWithHeaders".into())))
.add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into())))
.add_header(Header::new("boolean", HeaderValue::Bool(true)))
.add_header(Header::new("byte", HeaderValue::Byte(55i8)))
.add_header(Header::new("int", HeaderValue::Int32(100_000i32)))
.add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64)))
.add_header(Header::new("short", HeaderValue::Int16(16_000i16)))
.add_header(Header::new("string", HeaderValue::String("test".into())))
.add_header(Header::new("timestamp", HeaderValue::Timestamp(DateTime::from_secs(5))));
assert_eq!(expected_message, actual_message);
""",
)
unitTest(
"message_with_header_and_payload",
"""
let event = TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder()
.header("header")
.payload(Blob::new(&b"payload"[..]))
.build()
);
let result = ${format(generator)}().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let actual_message = result.unwrap();
let expected_message = Message::new(&b"payload"[..])
.add_header(Header::new(":message-type", HeaderValue::String("event".into())))
.add_header(Header::new(":event-type", HeaderValue::String("MessageWithHeaderAndPayload".into())))
.add_header(Header::new("header", HeaderValue::String("header".into())))
.add_header(Header::new(":content-type", HeaderValue::String("application/octet-stream".into())));
assert_eq!(expected_message, actual_message);
""",
)
unitTest(
"message_with_no_header_payload_traits",
"""
let event = TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder()
.some_int(5)
.some_string("hello")
.build()
);
let result = ${format(generator)}().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let message = result.unwrap();
let headers = headers_to_map(message.headers());
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
assert_eq!(&str_header("MessageWithNoHeaderPayloadTraits"), *headers.get(":event-type").unwrap());
assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap());
validate_body(
message.payload(),
${testCase.validMessageWithNoHeaderPayloadTraits.dq()},
MediaType::from(${testCase.requestContentType.dq()})
).unwrap();
""",
)
}
}

View File

@ -0,0 +1,179 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.core.testutil
import software.amazon.smithy.model.Model
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml
private fun fillInBaseModel(
protocolName: String,
extraServiceAnnotations: String = "",
): String = """
namespace test
use aws.protocols#$protocolName
union TestUnion {
Foo: String,
Bar: Integer,
}
structure TestStruct {
someString: String,
someInt: Integer,
}
@error("client")
structure SomeError {
Message: String,
}
structure MessageWithBlob { @eventPayload data: Blob }
structure MessageWithString { @eventPayload data: String }
structure MessageWithStruct { @eventPayload someStruct: TestStruct }
structure MessageWithUnion { @eventPayload someUnion: TestUnion }
structure MessageWithHeaders {
@eventHeader blob: Blob,
@eventHeader boolean: Boolean,
@eventHeader byte: Byte,
@eventHeader int: Integer,
@eventHeader long: Long,
@eventHeader short: Short,
@eventHeader string: String,
@eventHeader timestamp: Timestamp,
}
structure MessageWithHeaderAndPayload {
@eventHeader header: String,
@eventPayload payload: Blob,
}
structure MessageWithNoHeaderPayloadTraits {
someInt: Integer,
someString: String,
}
@streaming
union TestStream {
MessageWithBlob: MessageWithBlob,
MessageWithString: MessageWithString,
MessageWithStruct: MessageWithStruct,
MessageWithUnion: MessageWithUnion,
MessageWithHeaders: MessageWithHeaders,
MessageWithHeaderAndPayload: MessageWithHeaderAndPayload,
MessageWithNoHeaderPayloadTraits: MessageWithNoHeaderPayloadTraits,
SomeError: SomeError,
}
structure TestStreamInputOutput { @httpPayload @required value: TestStream }
operation TestStreamOp {
input: TestStreamInputOutput,
output: TestStreamInputOutput,
errors: [SomeError],
}
$extraServiceAnnotations
@$protocolName
service TestService { version: "123", operations: [TestStreamOp] }
"""
object EventStreamTestModels {
private fun restJson1(): Model = fillInBaseModel("restJson1").asSmithyModel()
private fun restXml(): Model = fillInBaseModel("restXml").asSmithyModel()
private fun awsJson11(): Model = fillInBaseModel("awsJson1_1").asSmithyModel()
private fun awsQuery(): Model =
fillInBaseModel("awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
private fun ec2Query(): Model =
fillInBaseModel("ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
data class TestCase(
val protocolShapeId: String,
val model: Model,
val requestContentType: String,
val responseContentType: String,
val validTestStruct: String,
val validMessageWithNoHeaderPayloadTraits: String,
val validTestUnion: String,
val validSomeError: String,
val validUnmodeledError: String,
val protocolBuilder: (CodegenContext) -> Protocol,
) {
override fun toString(): String = protocolShapeId
}
val TEST_CASES = listOf(
//
// restJson1
//
TestCase(
protocolShapeId = "aws.protocols#restJson1",
model = restJson1(),
requestContentType = "application/json",
responseContentType = "application/json",
validTestStruct = """{"someString":"hello","someInt":5}""",
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
validTestUnion = """{"Foo":"hello"}""",
validSomeError = """{"Message":"some error"}""",
validUnmodeledError = """{"Message":"unmodeled error"}""",
) { RestJson(it) },
//
// awsJson1_1
//
TestCase(
protocolShapeId = "aws.protocols#awsJson1_1",
model = awsJson11(),
requestContentType = "application/x-amz-json-1.1",
responseContentType = "application/x-amz-json-1.1",
validTestStruct = """{"someString":"hello","someInt":5}""",
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
validTestUnion = """{"Foo":"hello"}""",
validSomeError = """{"Message":"some error"}""",
validUnmodeledError = """{"Message":"unmodeled error"}""",
) { AwsJson(it, AwsJsonVersion.Json11) },
//
// restXml
//
TestCase(
protocolShapeId = "aws.protocols#restXml",
model = restXml(),
requestContentType = "application/xml",
responseContentType = "application/xml",
validTestStruct = """
<TestStruct>
<someString>hello</someString>
<someInt>5</someInt>
</TestStruct>
""".trimIndent(),
validMessageWithNoHeaderPayloadTraits = """
<MessageWithNoHeaderPayloadTraits>
<someString>hello</someString>
<someInt>5</someInt>
</MessageWithNoHeaderPayloadTraits>
""".trimIndent(),
validTestUnion = "<TestUnion><Foo>hello</Foo></TestUnion>",
validSomeError = """
<ErrorResponse>
<Error>
<Type>SomeError</Type>
<Code>SomeError</Code>
<Message>some error</Message>
</Error>
</ErrorResponse>
""".trimIndent(),
validUnmodeledError = """
<ErrorResponse>
<Error>
<Type>UnmodeledError</Type>
<Code>UnmodeledError</Code>
<Message>unmodeled error</Message>
</Error>
</ErrorResponse>
""".trimIndent(),
) { RestXml(it) },
)
}

View File

@ -0,0 +1,174 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.core.testutil
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.ErrorsModule
import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.CombinedErrorGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ServerCombinedErrorGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.transformers.EventStreamNormalizer
import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamMarshallTestCases.writeMarshallTestCases
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.writeUnmarshallTestCases
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.lookup
import software.amazon.smithy.rust.codegen.core.util.outputShape
import kotlin.streams.toList
data class TestEventStreamProject(
val model: Model,
val serviceShape: ServiceShape,
val operationShape: OperationShape,
val streamShape: UnionShape,
val symbolProvider: RustSymbolProvider,
val project: TestWriterDelegator,
)
enum class EventStreamTestVariety {
Marshall,
Unmarshall
}
interface EventStreamTestRequirements<C : CodegenContext> {
/** Create a codegen context for the tests */
fun createCodegenContext(
model: Model,
serviceShape: ServiceShape,
protocolShapeId: ShapeId,
codegenTarget: CodegenTarget,
): C
/** Render the event stream marshall/unmarshall code generator */
fun renderGenerator(
codegenContext: C,
project: TestEventStreamProject,
protocol: Protocol,
): RuntimeType
/** Render a builder for the given shape */
fun renderBuilderForShape(
writer: RustWriter,
codegenContext: C,
shape: StructureShape,
)
}
object EventStreamTestTools {
fun <C : CodegenContext> runTestCase(
testCase: EventStreamTestModels.TestCase,
requirements: EventStreamTestRequirements<C>,
codegenTarget: CodegenTarget,
variety: EventStreamTestVariety,
) {
val model = EventStreamNormalizer.transform(OperationNormalizer.transform(testCase.model))
val serviceShape = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape
val codegenContext = requirements.createCodegenContext(
model,
serviceShape,
ShapeId.from(testCase.protocolShapeId),
codegenTarget,
)
val test = generateTestProject(requirements, codegenContext, codegenTarget)
val protocol = testCase.protocolBuilder(codegenContext)
val generator = requirements.renderGenerator(codegenContext, test, protocol)
test.project.lib {
when (variety) {
EventStreamTestVariety.Marshall -> writeMarshallTestCases(testCase, generator)
EventStreamTestVariety.Unmarshall -> writeUnmarshallTestCases(testCase, codegenTarget, generator)
}
}
test.project.compileAndTest()
}
private fun <C : CodegenContext> generateTestProject(
requirements: EventStreamTestRequirements<C>,
codegenContext: C,
codegenTarget: CodegenTarget,
): TestEventStreamProject {
val model = codegenContext.model
val symbolProvider = codegenContext.symbolProvider
val operationShape = model.expectShape(ShapeId.from("test#TestStreamOp")) as OperationShape
val unionShape = model.expectShape(ShapeId.from("test#TestStream")) as UnionShape
val project = TestWorkspace.testProject(symbolProvider)
val operationSymbol = symbolProvider.toSymbol(operationShape)
project.withModule(ErrorsModule) {
val errors = model.shapes()
.filter { shape -> shape.isStructureShape && shape.hasTrait<ErrorTrait>() }
.map { it.asStructureShape().get() }
.toList()
when (codegenTarget) {
CodegenTarget.CLIENT -> CombinedErrorGenerator(model, symbolProvider, operationSymbol, errors).render(this)
CodegenTarget.SERVER -> ServerCombinedErrorGenerator(model, symbolProvider, operationSymbol, errors).render(this)
}
for (shape in model.shapes().filter { shape -> shape is StructureShape && shape.hasTrait<ErrorTrait>() }) {
StructureGenerator(model, symbolProvider, this, shape as StructureShape).render(codegenTarget)
requirements.renderBuilderForShape(this, codegenContext, shape)
}
}
project.withModule(ModelsModule) {
val inputOutput = model.lookup<StructureShape>("test#TestStreamInputOutput")
recursivelyGenerateModels(model, symbolProvider, inputOutput, this, codegenTarget)
}
project.withModule(RustModule.Output) {
operationShape.outputShape(model).renderWithModelBuilder(model, symbolProvider, this)
}
return TestEventStreamProject(
model,
codegenContext.serviceShape,
operationShape,
unionShape,
symbolProvider,
project,
)
}
private fun recursivelyGenerateModels(
model: Model,
symbolProvider: RustSymbolProvider,
shape: Shape,
writer: RustWriter,
mode: CodegenTarget,
) {
for (member in shape.members()) {
if (member.target.namespace == "smithy.api") {
continue
}
val target = model.expectShape(member.target)
when (target) {
is StructureShape -> target.renderWithModelBuilder(model, symbolProvider, writer)
is UnionShape -> UnionGenerator(
model,
symbolProvider,
writer,
target,
renderUnknownVariant = mode.renderUnknownVariant(),
).render()
else -> TODO("EventStreamTestTools doesn't support rendering $target")
}
recursivelyGenerateModels(model, symbolProvider, target, writer, mode)
}
}
}

View File

@ -0,0 +1,274 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.core.testutil
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
internal object EventStreamUnmarshallTestCases {
internal fun RustWriter.writeUnmarshallTestCases(
testCase: EventStreamTestModels.TestCase,
codegenTarget: CodegenTarget,
generator: RuntimeType,
) {
rust(
"""
use aws_smithy_eventstream::frame::{Header, HeaderValue, Message, UnmarshallMessage, UnmarshalledMessage};
use aws_smithy_types::{Blob, DateTime};
use crate::error::*;
use crate::model::*;
fn msg(
message_type: &'static str,
event_type: &'static str,
content_type: &'static str,
payload: &'static [u8],
) -> Message {
let message = Message::new(payload)
.add_header(Header::new(":message-type", HeaderValue::String(message_type.into())))
.add_header(Header::new(":content-type", HeaderValue::String(content_type.into())));
if message_type == "event" {
message.add_header(Header::new(":event-type", HeaderValue::String(event_type.into())))
} else {
message.add_header(Header::new(":exception-type", HeaderValue::String(event_type.into())))
}
}
fn expect_event<T: std::fmt::Debug, E: std::fmt::Debug>(unmarshalled: UnmarshalledMessage<T, E>) -> T {
match unmarshalled {
UnmarshalledMessage::Event(event) => event,
_ => panic!("expected event, got: {:?}", unmarshalled),
}
}
fn expect_error<T: std::fmt::Debug, E: std::fmt::Debug>(unmarshalled: UnmarshalledMessage<T, E>) -> E {
match unmarshalled {
UnmarshalledMessage::Error(error) => error,
_ => panic!("expected error, got: {:?}", unmarshalled),
}
}
""",
)
unitTest(
name = "message_with_blob",
test = """
let message = msg("event", "MessageWithBlob", "application/octet-stream", b"hello, world!");
let result = ${format(generator)}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithBlob(
MessageWithBlob::builder().data(Blob::new(&b"hello, world!"[..])).build()
),
expect_event(result.unwrap())
);
""",
)
if (codegenTarget == CodegenTarget.CLIENT) {
unitTest(
"unknown_message",
"""
let message = msg("event", "NewUnmodeledMessageType", "application/octet-stream", b"hello, world!");
let result = ${format(generator)}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::Unknown,
expect_event(result.unwrap())
);
""",
)
}
unitTest(
"message_with_string",
"""
let message = msg("event", "MessageWithString", "text/plain", b"hello, world!");
let result = ${format(generator)}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithString(MessageWithString::builder().data("hello, world!").build()),
expect_event(result.unwrap())
);
""",
)
unitTest(
"message_with_struct",
"""
let message = msg(
"event",
"MessageWithStruct",
"${testCase.responseContentType}",
br#"${testCase.validTestStruct}"#
);
let result = ${format(generator)}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithStruct(MessageWithStruct::builder().some_struct(
TestStruct::builder()
.some_string("hello")
.some_int(5)
.build()
).build()),
expect_event(result.unwrap())
);
""",
)
unitTest(
"message_with_union",
"""
let message = msg(
"event",
"MessageWithUnion",
"${testCase.responseContentType}",
br#"${testCase.validTestUnion}"#
);
let result = ${format(generator)}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithUnion(MessageWithUnion::builder().some_union(
TestUnion::Foo("hello".into())
).build()),
expect_event(result.unwrap())
);
""",
)
unitTest(
"message_with_headers",
"""
let message = msg("event", "MessageWithHeaders", "application/octet-stream", b"")
.add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into())))
.add_header(Header::new("boolean", HeaderValue::Bool(true)))
.add_header(Header::new("byte", HeaderValue::Byte(55i8)))
.add_header(Header::new("int", HeaderValue::Int32(100_000i32)))
.add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64)))
.add_header(Header::new("short", HeaderValue::Int16(16_000i16)))
.add_header(Header::new("string", HeaderValue::String("test".into())))
.add_header(Header::new("timestamp", HeaderValue::Timestamp(DateTime::from_secs(5))));
let result = ${format(generator)}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithHeaders(MessageWithHeaders::builder()
.blob(Blob::new(&b"test"[..]))
.boolean(true)
.byte(55i8)
.int(100_000i32)
.long(9_000_000_000i64)
.short(16_000i16)
.string("test")
.timestamp(DateTime::from_secs(5))
.build()
),
expect_event(result.unwrap())
);
""",
)
unitTest(
"message_with_header_and_payload",
"""
let message = msg("event", "MessageWithHeaderAndPayload", "application/octet-stream", b"payload")
.add_header(Header::new("header", HeaderValue::String("header".into())));
let result = ${format(generator)}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder()
.header("header")
.payload(Blob::new(&b"payload"[..]))
.build()
),
expect_event(result.unwrap())
);
""",
)
unitTest(
"message_with_no_header_payload_traits",
"""
let message = msg(
"event",
"MessageWithNoHeaderPayloadTraits",
"${testCase.responseContentType}",
br#"${testCase.validMessageWithNoHeaderPayloadTraits}"#
);
let result = ${format(generator)}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder()
.some_int(5)
.some_string("hello")
.build()
),
expect_event(result.unwrap())
);
""",
)
val (someError, kindSuffix) = when (codegenTarget) {
CodegenTarget.CLIENT -> "TestStreamErrorKind::SomeError" to ".kind"
CodegenTarget.SERVER -> "TestStreamError::SomeError" to ""
}
unitTest(
"some_error",
"""
let message = msg(
"exception",
"SomeError",
"${testCase.responseContentType}",
br#"${testCase.validSomeError}"#
);
let result = ${format(generator)}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
match expect_error(result.unwrap())$kindSuffix {
$someError(err) => assert_eq!(Some("some error"), err.message()),
kind => panic!("expected SomeError, but got {:?}", kind),
}
""",
)
if (codegenTarget == CodegenTarget.CLIENT) {
unitTest(
"generic_error",
"""
let message = msg(
"exception",
"UnmodeledError",
"${testCase.responseContentType}",
br#"${testCase.validUnmodeledError}"#
);
let result = ${format(generator)}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
match expect_error(result.unwrap())$kindSuffix {
TestStreamErrorKind::Unhandled(err) => {
let message = format!("{}", aws_smithy_types::error::display::DisplayErrorContext(&err));
let expected = "message: \"unmodeled error\"";
assert!(message.contains(expected), "Expected '{message}' to contain '{expected}'");
}
kind => panic!("expected generic error, but got {:?}", kind),
}
""",
)
}
unitTest(
"bad_content_type",
"""
let message = msg(
"event",
"MessageWithBlob",
"wrong-content-type",
br#"${testCase.validTestStruct}"#
);
let result = ${format(generator)}().unmarshall(&message);
assert!(result.is_err(), "expected error, got: {:?}", result);
assert!(format!("{}", result.err().unwrap()).contains("expected :content-type to be"));
""",
)
}
}

View File

@ -3,11 +3,10 @@
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.client.customizations
package software.amazon.smithy.rust.codegen.core.smithy.customizations
import org.junit.jupiter.api.Test
import software.amazon.smithy.model.Model
import software.amazon.smithy.rust.codegen.client.smithy.customizations.pubUseTypes
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel

View File

@ -24,7 +24,6 @@ val smithyVersion: String by project
dependencies {
implementation(project(":codegen-core"))
implementation(project(":codegen-client"))
implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion")
implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion")
}

View File

@ -10,10 +10,10 @@ import software.amazon.smithy.build.SmithyBuildPlugin
import software.amazon.smithy.codegen.core.ReservedWordSymbolProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.rust.codegen.client.smithy.EventStreamSymbolProvider
import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWordSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.BaseSymbolMetadataProvider
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.EventStreamSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitor
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig
import software.amazon.smithy.rust.codegen.server.python.smithy.customizations.DECORATORS

View File

@ -10,12 +10,12 @@ import software.amazon.smithy.build.SmithyBuildPlugin
import software.amazon.smithy.codegen.core.ReservedWordSymbolProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.rust.codegen.client.smithy.EventStreamSymbolProvider
import software.amazon.smithy.rust.codegen.client.smithy.StreamingShapeMetadataProvider
import software.amazon.smithy.rust.codegen.client.smithy.StreamingShapeSymbolProvider
import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWordSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.BaseSymbolMetadataProvider
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.EventStreamSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.StreamingShapeMetadataProvider
import software.amazon.smithy.rust.codegen.core.smithy.StreamingShapeSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitor
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig
import software.amazon.smithy.rust.codegen.server.smithy.customizations.ServerRequiredCustomizations

View File

@ -5,11 +5,11 @@
package software.amazon.smithy.rust.codegen.server.smithy.customizations
import software.amazon.smithy.rust.codegen.client.smithy.customizations.AllowLintsGenerator
import software.amazon.smithy.rust.codegen.client.smithy.customizations.CrateVersionGenerator
import software.amazon.smithy.rust.codegen.client.smithy.customizations.pubUseSmithyTypes
import software.amazon.smithy.rust.codegen.core.rustlang.Feature
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.customizations.AllowLintsCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customizations.CrateVersionCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customizations.pubUseSmithyTypes
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator
@ -29,7 +29,7 @@ class ServerRequiredCustomizations : ServerCodegenDecorator {
codegenContext: ServerCodegenContext,
baseCustomizations: List<LibRsCustomization>,
): List<LibRsCustomization> =
baseCustomizations + CrateVersionGenerator() + AllowLintsGenerator()
baseCustomizations + CrateVersionCustomization() + AllowLintsCustomization()
override fun extras(codegenContext: ServerCodegenContext, rustCrate: RustCrate) {
// Add rt-tokio feature for `ByteStream::from_path`

View File

@ -39,12 +39,20 @@ private fun testServiceShapeFor(model: Model) =
fun serverTestSymbolProvider(model: Model, serviceShape: ServiceShape? = null) =
serverTestSymbolProviders(model, serviceShape).symbolProvider
fun serverTestSymbolProviders(model: Model, serviceShape: ServiceShape? = null) =
fun serverTestSymbolProviders(
model: Model,
serviceShape: ServiceShape? = null,
settings: ServerRustSettings? = null,
) =
ServerSymbolProviders.from(
model,
serviceShape ?: testServiceShapeFor(model),
ServerTestSymbolVisitorConfig,
serverTestRustSettings((serviceShape ?: testServiceShapeFor(model)).id).codegenConfig.publicConstrainedTypes,
(
settings ?: serverTestRustSettings(
(serviceShape ?: testServiceShapeFor(model)).id,
)
).codegenConfig.publicConstrainedTypes,
RustCodegenServerPlugin::baseSymbolProvider,
)

View File

@ -1,407 +0,0 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.server.smithy.protocols
import org.junit.jupiter.api.extension.ExtensionContext
import org.junit.jupiter.params.provider.Arguments
import org.junit.jupiter.params.provider.ArgumentsProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.rust.codegen.client.testutil.testSymbolProvider
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.ErrorsModule
import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.CombinedErrorGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ServerCombinedErrorGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock
import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsQueryProtocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Ec2QueryProtocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml
import software.amazon.smithy.rust.codegen.core.smithy.transformers.EventStreamNormalizer
import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace
import software.amazon.smithy.rust.codegen.core.testutil.TestWriterDelegator
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.lookup
import software.amazon.smithy.rust.codegen.core.util.outputShape
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider
import java.util.stream.Stream
import kotlin.streams.toList
private fun fillInBaseModel(
protocolName: String,
extraServiceAnnotations: String = "",
): String = """
namespace test
use aws.protocols#$protocolName
union TestUnion {
Foo: String,
Bar: Integer,
}
structure TestStruct {
someString: String,
someInt: Integer,
}
@error("client")
structure SomeError {
Message: String,
}
structure MessageWithBlob { @eventPayload data: Blob }
structure MessageWithString { @eventPayload data: String }
structure MessageWithStruct { @eventPayload someStruct: TestStruct }
structure MessageWithUnion { @eventPayload someUnion: TestUnion }
structure MessageWithHeaders {
@eventHeader blob: Blob,
@eventHeader boolean: Boolean,
@eventHeader byte: Byte,
@eventHeader int: Integer,
@eventHeader long: Long,
@eventHeader short: Short,
@eventHeader string: String,
@eventHeader timestamp: Timestamp,
}
structure MessageWithHeaderAndPayload {
@eventHeader header: String,
@eventPayload payload: Blob,
}
structure MessageWithNoHeaderPayloadTraits {
someInt: Integer,
someString: String,
}
@streaming
union TestStream {
MessageWithBlob: MessageWithBlob,
MessageWithString: MessageWithString,
MessageWithStruct: MessageWithStruct,
MessageWithUnion: MessageWithUnion,
MessageWithHeaders: MessageWithHeaders,
MessageWithHeaderAndPayload: MessageWithHeaderAndPayload,
MessageWithNoHeaderPayloadTraits: MessageWithNoHeaderPayloadTraits,
SomeError: SomeError,
}
structure TestStreamInputOutput { @httpPayload @required value: TestStream }
operation TestStreamOp {
input: TestStreamInputOutput,
output: TestStreamInputOutput,
errors: [SomeError],
}
$extraServiceAnnotations
@$protocolName
service TestService { version: "123", operations: [TestStreamOp] }
"""
object EventStreamTestModels {
private fun restJson1(): Model = fillInBaseModel("restJson1").asSmithyModel()
private fun restXml(): Model = fillInBaseModel("restXml").asSmithyModel()
private fun awsJson11(): Model = fillInBaseModel("awsJson1_1").asSmithyModel()
private fun awsQuery(): Model =
fillInBaseModel("awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
private fun ec2Query(): Model =
fillInBaseModel("ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
data class TestCase(
val protocolShapeId: String,
val model: Model,
val requestContentType: String,
val responseContentType: String,
val validTestStruct: String,
val validMessageWithNoHeaderPayloadTraits: String,
val validTestUnion: String,
val validSomeError: String,
val validUnmodeledError: String,
val target: CodegenTarget = CodegenTarget.CLIENT,
val protocolBuilder: (CodegenContext) -> Protocol,
) {
override fun toString(): String = protocolShapeId
}
private val testCases = listOf(
//
// restJson1
//
TestCase(
protocolShapeId = "aws.protocols#restJson1",
model = restJson1(),
requestContentType = "application/json",
responseContentType = "application/json",
validTestStruct = """{"someString":"hello","someInt":5}""",
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
validTestUnion = """{"Foo":"hello"}""",
validSomeError = """{"Message":"some error"}""",
validUnmodeledError = """{"Message":"unmodeled error"}""",
) { RestJson(it) },
//
// restJson1, server mode
//
TestCase(
protocolShapeId = "aws.protocols#restJson1",
model = restJson1(),
requestContentType = "application/json",
responseContentType = "application/json",
validTestStruct = """{"someString":"hello","someInt":5}""",
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
validTestUnion = """{"Foo":"hello"}""",
validSomeError = """{"Message":"some error"}""",
validUnmodeledError = """{"Message":"unmodeled error"}""",
) { RestJson(it) },
//
// awsJson1_1
//
TestCase(
protocolShapeId = "aws.protocols#awsJson1_1",
model = awsJson11(),
requestContentType = "application/x-amz-json-1.1",
responseContentType = "application/x-amz-json-1.1",
validTestStruct = """{"someString":"hello","someInt":5}""",
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
validTestUnion = """{"Foo":"hello"}""",
validSomeError = """{"Message":"some error"}""",
validUnmodeledError = """{"Message":"unmodeled error"}""",
) { AwsJson(it, AwsJsonVersion.Json11) },
//
// restXml
//
TestCase(
protocolShapeId = "aws.protocols#restXml",
model = restXml(),
requestContentType = "application/xml",
responseContentType = "application/xml",
validTestStruct = """
<TestStruct>
<someString>hello</someString>
<someInt>5</someInt>
</TestStruct>
""".trimIndent(),
validMessageWithNoHeaderPayloadTraits = """
<MessageWithNoHeaderPayloadTraits>
<someString>hello</someString>
<someInt>5</someInt>
</MessageWithNoHeaderPayloadTraits>
""".trimIndent(),
validTestUnion = "<TestUnion><Foo>hello</Foo></TestUnion>",
validSomeError = """
<ErrorResponse>
<Error>
<Type>SomeError</Type>
<Code>SomeError</Code>
<Message>some error</Message>
</Error>
</ErrorResponse>
""".trimIndent(),
validUnmodeledError = """
<ErrorResponse>
<Error>
<Type>UnmodeledError</Type>
<Code>UnmodeledError</Code>
<Message>unmodeled error</Message>
</Error>
</ErrorResponse>
""".trimIndent(),
) { RestXml(it) },
//
// awsQuery
//
TestCase(
protocolShapeId = "aws.protocols#awsQuery",
model = awsQuery(),
requestContentType = "application/x-www-form-urlencoded",
responseContentType = "text/xml",
validTestStruct = """
<TestStruct>
<someString>hello</someString>
<someInt>5</someInt>
</TestStruct>
""".trimIndent(),
validMessageWithNoHeaderPayloadTraits = """
<MessageWithNoHeaderPayloadTraits>
<someString>hello</someString>
<someInt>5</someInt>
</MessageWithNoHeaderPayloadTraits>
""".trimIndent(),
validTestUnion = "<TestUnion><Foo>hello</Foo></TestUnion>",
validSomeError = """
<ErrorResponse>
<Error>
<Type>SomeError</Type>
<Code>SomeError</Code>
<Message>some error</Message>
</Error>
</ErrorResponse>
""".trimIndent(),
validUnmodeledError = """
<ErrorResponse>
<Error>
<Type>UnmodeledError</Type>
<Code>UnmodeledError</Code>
<Message>unmodeled error</Message>
</Error>
</ErrorResponse>
""".trimIndent(),
) { AwsQueryProtocol(it) },
//
// ec2Query
//
TestCase(
protocolShapeId = "aws.protocols#ec2Query",
model = ec2Query(),
requestContentType = "application/x-www-form-urlencoded",
responseContentType = "text/xml",
validTestStruct = """
<TestStruct>
<someString>hello</someString>
<someInt>5</someInt>
</TestStruct>
""".trimIndent(),
validMessageWithNoHeaderPayloadTraits = """
<MessageWithNoHeaderPayloadTraits>
<someString>hello</someString>
<someInt>5</someInt>
</MessageWithNoHeaderPayloadTraits>
""".trimIndent(),
validTestUnion = "<TestUnion><Foo>hello</Foo></TestUnion>",
validSomeError = """
<Response>
<Errors>
<Error>
<Type>SomeError</Type>
<Code>SomeError</Code>
<Message>some error</Message>
</Error>
</Errors>
</Response>
""".trimIndent(),
validUnmodeledError = """
<Response>
<Errors>
<Error>
<Type>UnmodeledError</Type>
<Code>UnmodeledError</Code>
<Message>unmodeled error</Message>
</Error>
</Errors>
</Response>
""".trimIndent(),
) { Ec2QueryProtocol(it) },
)
// TODO(https://github.com/awslabs/smithy-rs/issues/1442) Server tests
// should be run from the server subproject using the
// `serverTestSymbolProvider()`.
// .flatMap { listOf(it, it.copy(target = CodegenTarget.SERVER)) }
class UnmarshallTestCasesProvider : ArgumentsProvider {
override fun provideArguments(context: ExtensionContext?): Stream<out Arguments> =
testCases.map { Arguments.of(it) }.stream()
}
class MarshallTestCasesProvider : ArgumentsProvider {
override fun provideArguments(context: ExtensionContext?): Stream<out Arguments> =
// Don't include awsQuery or ec2Query for now since marshall support for them is unimplemented
testCases
.filter { testCase -> !testCase.protocolShapeId.contains("Query") }
.map { Arguments.of(it) }.stream()
}
}
data class TestEventStreamProject(
val model: Model,
val serviceShape: ServiceShape,
val operationShape: OperationShape,
val streamShape: UnionShape,
val symbolProvider: RustSymbolProvider,
val project: TestWriterDelegator,
)
object EventStreamTestTools {
fun generateTestProject(testCase: EventStreamTestModels.TestCase): TestEventStreamProject {
val model = EventStreamNormalizer.transform(OperationNormalizer.transform(testCase.model))
val serviceShape = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape
val operationShape = model.expectShape(ShapeId.from("test#TestStreamOp")) as OperationShape
val unionShape = model.expectShape(ShapeId.from("test#TestStream")) as UnionShape
val symbolProvider = when (testCase.target) {
CodegenTarget.CLIENT -> testSymbolProvider(model)
CodegenTarget.SERVER -> serverTestSymbolProvider(model)
}
val project = TestWorkspace.testProject(symbolProvider)
val operationSymbol = symbolProvider.toSymbol(operationShape)
project.withModule(ErrorsModule) {
val errors = model.shapes()
.filter { shape -> shape.isStructureShape && shape.hasTrait<ErrorTrait>() }
.map { it.asStructureShape().get() }
.toList()
when (testCase.target) {
CodegenTarget.CLIENT -> CombinedErrorGenerator(model, symbolProvider, operationSymbol, errors).render(this)
CodegenTarget.SERVER -> ServerCombinedErrorGenerator(model, symbolProvider, operationSymbol, errors).render(this)
}
for (shape in model.shapes().filter { shape -> shape.isStructureShape && shape.hasTrait<ErrorTrait>() }) {
StructureGenerator(model, symbolProvider, this, shape as StructureShape).render(testCase.target)
val builderGen = BuilderGenerator(model, symbolProvider, shape)
builderGen.render(this)
implBlock(shape, symbolProvider) {
builderGen.renderConvenienceMethod(this)
}
}
}
project.withModule(ModelsModule) {
val inputOutput = model.lookup<StructureShape>("test#TestStreamInputOutput")
recursivelyGenerateModels(model, symbolProvider, inputOutput, this, testCase.target)
}
project.withModule(RustModule.Output) {
operationShape.outputShape(model).renderWithModelBuilder(model, symbolProvider, this)
}
return TestEventStreamProject(model, serviceShape, operationShape, unionShape, symbolProvider, project)
}
private fun recursivelyGenerateModels(
model: Model,
symbolProvider: RustSymbolProvider,
shape: Shape,
writer: RustWriter,
mode: CodegenTarget,
) {
for (member in shape.members()) {
val target = model.expectShape(member.target)
if (target is StructureShape || target is UnionShape) {
if (target is StructureShape) {
target.renderWithModelBuilder(model, symbolProvider, writer)
} else if (target is UnionShape) {
UnionGenerator(model, symbolProvider, writer, target, renderUnknownVariant = mode.renderUnknownVariant()).render()
}
recursivelyGenerateModels(model, symbolProvider, target, writer, mode)
}
}
}
}

View File

@ -0,0 +1,83 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.server.smithy.protocols.eventstream
import org.junit.jupiter.api.extension.ExtensionContext
import org.junit.jupiter.params.provider.Arguments
import org.junit.jupiter.params.provider.ArgumentsProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestRequirements
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenConfig
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGeneratorWithoutPublicConstrainedTypes
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestRustSettings
import java.util.stream.Stream
data class TestCase(
val eventStreamTestCase: EventStreamTestModels.TestCase,
val publicConstrainedTypes: Boolean,
) {
override fun toString(): String = "$eventStreamTestCase, publicConstrainedTypes = $publicConstrainedTypes"
}
class TestCasesProvider : ArgumentsProvider {
override fun provideArguments(context: ExtensionContext?): Stream<out Arguments> =
EventStreamTestModels.TEST_CASES
.flatMap { testCase ->
listOf(
TestCase(testCase, publicConstrainedTypes = false),
TestCase(testCase, publicConstrainedTypes = true),
)
}.map { Arguments.of(it) }.stream()
}
abstract class ServerEventStreamBaseRequirements : EventStreamTestRequirements<ServerCodegenContext> {
abstract val publicConstrainedTypes: Boolean
override fun createCodegenContext(
model: Model,
serviceShape: ServiceShape,
protocolShapeId: ShapeId,
codegenTarget: CodegenTarget,
): ServerCodegenContext = serverTestCodegenContext(
model, serviceShape,
serverTestRustSettings(
codegenConfig = ServerCodegenConfig(publicConstrainedTypes = publicConstrainedTypes),
),
protocolShapeId,
)
override fun renderBuilderForShape(
writer: RustWriter,
codegenContext: ServerCodegenContext,
shape: StructureShape,
) {
if (codegenContext.settings.codegenConfig.publicConstrainedTypes) {
ServerBuilderGenerator(codegenContext, shape).apply {
render(writer)
writer.implBlock(shape, codegenContext.symbolProvider) {
renderConvenienceMethod(writer)
}
}
} else {
ServerBuilderGeneratorWithoutPublicConstrainedTypes(codegenContext, shape).apply {
render(writer)
writer.implBlock(shape, codegenContext.symbolProvider) {
renderConvenienceMethod(writer)
}
}
}
}
}

View File

@ -0,0 +1,49 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.server.smithy.protocols.eventstream
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ArgumentsSource
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.EventStreamMarshallerGenerator
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety
import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject
import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
class ServerEventStreamMarshallerGeneratorTest {
@ParameterizedTest
@ArgumentsSource(TestCasesProvider::class)
fun test(testCase: TestCase) {
EventStreamTestTools.runTestCase(
testCase.eventStreamTestCase,
object : ServerEventStreamBaseRequirements() {
override val publicConstrainedTypes: Boolean get() = testCase.publicConstrainedTypes
override fun renderGenerator(
codegenContext: ServerCodegenContext,
project: TestEventStreamProject,
protocol: Protocol,
): RuntimeType {
return EventStreamMarshallerGenerator(
project.model,
CodegenTarget.SERVER,
TestRuntimeConfig,
project.symbolProvider,
project.streamShape,
protocol.structuredDataSerializer(project.operationShape),
testCase.eventStreamTestCase.requestContentType,
).render()
}
},
CodegenTarget.SERVER,
EventStreamTestVariety.Marshall,
)
}
}

View File

@ -0,0 +1,73 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.server.smithy.protocols.eventstream
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ArgumentsSource
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.EventStreamUnmarshallerGenerator
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety
import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol
class ServerEventStreamUnmarshallerGeneratorTest {
@ParameterizedTest
@ArgumentsSource(TestCasesProvider::class)
fun test(testCase: TestCase) {
// TODO(https://github.com/awslabs/smithy-rs/issues/1442): Enable tests for `publicConstrainedTypes = false`
// by deleting this if/return
if (!testCase.publicConstrainedTypes) {
return
}
EventStreamTestTools.runTestCase(
testCase.eventStreamTestCase,
object : ServerEventStreamBaseRequirements() {
override val publicConstrainedTypes: Boolean get() = testCase.publicConstrainedTypes
override fun renderGenerator(
codegenContext: ServerCodegenContext,
project: TestEventStreamProject,
protocol: Protocol,
): RuntimeType {
fun builderSymbol(shape: StructureShape): Symbol = shape.serverBuilderSymbol(codegenContext)
return EventStreamUnmarshallerGenerator(
protocol,
codegenContext,
project.operationShape,
project.streamShape,
::builderSymbol,
).render()
}
// TODO(https://github.com/awslabs/smithy-rs/issues/1442): Delete this function override to use the correct builder from the parent class
override fun renderBuilderForShape(
writer: RustWriter,
codegenContext: ServerCodegenContext,
shape: StructureShape,
) {
BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape).apply {
render(writer)
writer.implBlock(shape, codegenContext.symbolProvider) {
renderConvenienceMethod(writer)
}
}
}
},
CodegenTarget.SERVER,
EventStreamTestVariety.Unmarshall,
)
}
}

View File

@ -1,306 +0,0 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.server.smithy.protocols.parse
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ArgumentsSource
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.EventStreamUnmarshallerGenerator
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
import software.amazon.smithy.rust.codegen.core.testutil.testRustSettings
import software.amazon.smithy.rust.codegen.core.testutil.unitTest
import software.amazon.smithy.rust.codegen.server.smithy.protocols.EventStreamTestModels
import software.amazon.smithy.rust.codegen.server.smithy.protocols.EventStreamTestTools
class EventStreamUnmarshallerGeneratorTest {
@ParameterizedTest
@ArgumentsSource(EventStreamTestModels.UnmarshallTestCasesProvider::class)
fun test(testCase: EventStreamTestModels.TestCase) {
val test = EventStreamTestTools.generateTestProject(testCase)
val codegenContext = CodegenContext(
test.model,
test.symbolProvider,
test.serviceShape,
ShapeId.from(testCase.protocolShapeId),
testRustSettings(),
target = testCase.target,
)
val protocol = testCase.protocolBuilder(codegenContext)
fun builderSymbol(shape: StructureShape): Symbol = shape.builderSymbol(codegenContext.symbolProvider)
val generator = EventStreamUnmarshallerGenerator(
protocol,
codegenContext,
test.operationShape,
test.streamShape,
::builderSymbol,
)
test.project.lib {
rust(
"""
use aws_smithy_eventstream::frame::{Header, HeaderValue, Message, UnmarshallMessage, UnmarshalledMessage};
use aws_smithy_types::{Blob, DateTime};
use crate::error::*;
use crate::model::*;
fn msg(
message_type: &'static str,
event_type: &'static str,
content_type: &'static str,
payload: &'static [u8],
) -> Message {
let message = Message::new(payload)
.add_header(Header::new(":message-type", HeaderValue::String(message_type.into())))
.add_header(Header::new(":content-type", HeaderValue::String(content_type.into())));
if message_type == "event" {
message.add_header(Header::new(":event-type", HeaderValue::String(event_type.into())))
} else {
message.add_header(Header::new(":exception-type", HeaderValue::String(event_type.into())))
}
}
fn expect_event<T: std::fmt::Debug, E: std::fmt::Debug>(unmarshalled: UnmarshalledMessage<T, E>) -> T {
match unmarshalled {
UnmarshalledMessage::Event(event) => event,
_ => panic!("expected event, got: {:?}", unmarshalled),
}
}
fn expect_error<T: std::fmt::Debug, E: std::fmt::Debug>(unmarshalled: UnmarshalledMessage<T, E>) -> E {
match unmarshalled {
UnmarshalledMessage::Error(error) => error,
_ => panic!("expected error, got: {:?}", unmarshalled),
}
}
""",
)
unitTest(
name = "message_with_blob",
test = """
let message = msg("event", "MessageWithBlob", "application/octet-stream", b"hello, world!");
let result = ${format(generator.render())}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithBlob(
MessageWithBlob::builder().data(Blob::new(&b"hello, world!"[..])).build()
),
expect_event(result.unwrap())
);
""",
)
if (testCase.target == CodegenTarget.CLIENT) {
unitTest(
"unknown_message",
"""
let message = msg("event", "NewUnmodeledMessageType", "application/octet-stream", b"hello, world!");
let result = ${format(generator.render())}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::Unknown,
expect_event(result.unwrap())
);
""",
)
}
unitTest(
"message_with_string",
"""
let message = msg("event", "MessageWithString", "text/plain", b"hello, world!");
let result = ${format(generator.render())}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithString(MessageWithString::builder().data("hello, world!").build()),
expect_event(result.unwrap())
);
""",
)
unitTest(
"message_with_struct",
"""
let message = msg(
"event",
"MessageWithStruct",
"${testCase.responseContentType}",
br#"${testCase.validTestStruct}"#
);
let result = ${format(generator.render())}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithStruct(MessageWithStruct::builder().some_struct(
TestStruct::builder()
.some_string("hello")
.some_int(5)
.build()
).build()),
expect_event(result.unwrap())
);
""",
)
unitTest(
"message_with_union",
"""
let message = msg(
"event",
"MessageWithUnion",
"${testCase.responseContentType}",
br#"${testCase.validTestUnion}"#
);
let result = ${format(generator.render())}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithUnion(MessageWithUnion::builder().some_union(
TestUnion::Foo("hello".into())
).build()),
expect_event(result.unwrap())
);
""",
)
unitTest(
"message_with_headers",
"""
let message = msg("event", "MessageWithHeaders", "application/octet-stream", b"")
.add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into())))
.add_header(Header::new("boolean", HeaderValue::Bool(true)))
.add_header(Header::new("byte", HeaderValue::Byte(55i8)))
.add_header(Header::new("int", HeaderValue::Int32(100_000i32)))
.add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64)))
.add_header(Header::new("short", HeaderValue::Int16(16_000i16)))
.add_header(Header::new("string", HeaderValue::String("test".into())))
.add_header(Header::new("timestamp", HeaderValue::Timestamp(DateTime::from_secs(5))));
let result = ${format(generator.render())}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithHeaders(MessageWithHeaders::builder()
.blob(Blob::new(&b"test"[..]))
.boolean(true)
.byte(55i8)
.int(100_000i32)
.long(9_000_000_000i64)
.short(16_000i16)
.string("test")
.timestamp(DateTime::from_secs(5))
.build()
),
expect_event(result.unwrap())
);
""",
)
unitTest(
"message_with_header_and_payload",
"""
let message = msg("event", "MessageWithHeaderAndPayload", "application/octet-stream", b"payload")
.add_header(Header::new("header", HeaderValue::String("header".into())));
let result = ${format(generator.render())}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder()
.header("header")
.payload(Blob::new(&b"payload"[..]))
.build()
),
expect_event(result.unwrap())
);
""",
)
unitTest(
"message_with_no_header_payload_traits",
"""
let message = msg(
"event",
"MessageWithNoHeaderPayloadTraits",
"${testCase.responseContentType}",
br#"${testCase.validMessageWithNoHeaderPayloadTraits}"#
);
let result = ${format(generator.render())}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
assert_eq!(
TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder()
.some_int(5)
.some_string("hello")
.build()
),
expect_event(result.unwrap())
);
""",
)
val (someError, kindSuffix) = when (testCase.target) {
CodegenTarget.CLIENT -> listOf("TestStreamErrorKind::SomeError", ".kind")
CodegenTarget.SERVER -> listOf("TestStreamError::SomeError", "")
}
unitTest(
"some_error",
"""
let message = msg(
"exception",
"SomeError",
"${testCase.responseContentType}",
br#"${testCase.validSomeError}"#
);
let result = ${format(generator.render())}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
match expect_error(result.unwrap())$kindSuffix {
$someError(err) => assert_eq!(Some("some error"), err.message()),
kind => panic!("expected SomeError, but got {:?}", kind),
}
""",
)
if (testCase.target == CodegenTarget.CLIENT) {
unitTest(
"generic_error",
"""
let message = msg(
"exception",
"UnmodeledError",
"${testCase.responseContentType}",
br#"${testCase.validUnmodeledError}"#
);
let result = ${format(generator.render())}().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
match expect_error(result.unwrap())$kindSuffix {
TestStreamErrorKind::Unhandled(err) => {
let message = format!("{}", aws_smithy_types::error::display::DisplayErrorContext(&err));
let expected = "message: \"unmodeled error\"";
assert!(message.contains(expected), "Expected '{message}' to contain '{expected}'");
}
kind => panic!("expected generic error, but got {:?}", kind),
}
""",
)
}
unitTest(
"bad_content_type",
"""
let message = msg(
"event",
"MessageWithBlob",
"wrong-content-type",
br#"${testCase.validTestStruct}"#
);
let result = ${format(generator.render())}().unmarshall(&message);
assert!(result.is_err(), "expected error, got: {:?}", result);
assert!(format!("{}", result.err().unwrap()).contains("expected :content-type to be"));
""",
)
}
test.project.compileAndTest()
}
}

View File

@ -1,240 +0,0 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.server.smithy.protocols.serialize
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ArgumentsSource
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.EventStreamMarshallerGenerator
import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
import software.amazon.smithy.rust.codegen.core.testutil.testRustSettings
import software.amazon.smithy.rust.codegen.core.testutil.unitTest
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.server.smithy.protocols.EventStreamTestModels
import software.amazon.smithy.rust.codegen.server.smithy.protocols.EventStreamTestTools
class EventStreamMarshallerGeneratorTest {
@ParameterizedTest
@ArgumentsSource(EventStreamTestModels.MarshallTestCasesProvider::class)
fun test(testCase: EventStreamTestModels.TestCase) {
val test = EventStreamTestTools.generateTestProject(testCase)
val codegenContext = CodegenContext(
test.model,
test.symbolProvider,
test.serviceShape,
ShapeId.from(testCase.protocolShapeId),
testRustSettings(),
target = testCase.target,
)
val protocol = testCase.protocolBuilder(codegenContext)
val generator = EventStreamMarshallerGenerator(
test.model,
testCase.target,
TestRuntimeConfig,
test.symbolProvider,
test.streamShape,
protocol.structuredDataSerializer(test.operationShape),
testCase.requestContentType,
)
test.project.lib {
val protocolTestHelpers = CargoDependency.smithyProtocolTestHelpers(TestRuntimeConfig)
.copy(scope = DependencyScope.Compile)
rustTemplate(
"""
use aws_smithy_eventstream::frame::{Message, Header, HeaderValue, MarshallMessage};
use std::collections::HashMap;
use aws_smithy_types::{Blob, DateTime};
use crate::error::*;
use crate::model::*;
use #{validate_body};
use #{MediaType};
fn headers_to_map<'a>(headers: &'a [Header]) -> HashMap<String, &'a HeaderValue> {
let mut map = HashMap::new();
for header in headers {
map.insert(header.name().as_str().to_string(), header.value());
}
map
}
fn str_header(value: &'static str) -> HeaderValue {
HeaderValue::String(value.into())
}
""",
"validate_body" to protocolTestHelpers.toType().resolve("validate_body"),
"MediaType" to protocolTestHelpers.toType().resolve("MediaType"),
)
unitTest(
"message_with_blob",
"""
let event = TestStream::MessageWithBlob(
MessageWithBlob::builder().data(Blob::new(&b"hello, world!"[..])).build()
);
let result = ${format(generator.render())}().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let message = result.unwrap();
let headers = headers_to_map(message.headers());
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
assert_eq!(&str_header("MessageWithBlob"), *headers.get(":event-type").unwrap());
assert_eq!(&str_header("application/octet-stream"), *headers.get(":content-type").unwrap());
assert_eq!(&b"hello, world!"[..], message.payload());
""",
)
unitTest(
"message_with_string",
"""
let event = TestStream::MessageWithString(
MessageWithString::builder().data("hello, world!").build()
);
let result = ${format(generator.render())}().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let message = result.unwrap();
let headers = headers_to_map(message.headers());
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
assert_eq!(&str_header("MessageWithString"), *headers.get(":event-type").unwrap());
assert_eq!(&str_header("text/plain"), *headers.get(":content-type").unwrap());
assert_eq!(&b"hello, world!"[..], message.payload());
""",
)
unitTest(
"message_with_struct",
"""
let event = TestStream::MessageWithStruct(
MessageWithStruct::builder().some_struct(
TestStruct::builder()
.some_string("hello")
.some_int(5)
.build()
).build()
);
let result = ${format(generator.render())}().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let message = result.unwrap();
let headers = headers_to_map(message.headers());
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
assert_eq!(&str_header("MessageWithStruct"), *headers.get(":event-type").unwrap());
assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap());
validate_body(
message.payload(),
${testCase.validTestStruct.dq()},
MediaType::from(${testCase.requestContentType.dq()})
).unwrap();
""",
)
unitTest(
"message_with_union",
"""
let event = TestStream::MessageWithUnion(MessageWithUnion::builder().some_union(
TestUnion::Foo("hello".into())
).build());
let result = ${format(generator.render())}().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let message = result.unwrap();
let headers = headers_to_map(message.headers());
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
assert_eq!(&str_header("MessageWithUnion"), *headers.get(":event-type").unwrap());
assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap());
validate_body(
message.payload(),
${testCase.validTestUnion.dq()},
MediaType::from(${testCase.requestContentType.dq()})
).unwrap();
""",
)
unitTest(
"message_with_headers",
"""
let event = TestStream::MessageWithHeaders(MessageWithHeaders::builder()
.blob(Blob::new(&b"test"[..]))
.boolean(true)
.byte(55i8)
.int(100_000i32)
.long(9_000_000_000i64)
.short(16_000i16)
.string("test")
.timestamp(DateTime::from_secs(5))
.build()
);
let result = ${format(generator.render())}().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let actual_message = result.unwrap();
let expected_message = Message::new(&b""[..])
.add_header(Header::new(":message-type", HeaderValue::String("event".into())))
.add_header(Header::new(":event-type", HeaderValue::String("MessageWithHeaders".into())))
.add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into())))
.add_header(Header::new("boolean", HeaderValue::Bool(true)))
.add_header(Header::new("byte", HeaderValue::Byte(55i8)))
.add_header(Header::new("int", HeaderValue::Int32(100_000i32)))
.add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64)))
.add_header(Header::new("short", HeaderValue::Int16(16_000i16)))
.add_header(Header::new("string", HeaderValue::String("test".into())))
.add_header(Header::new("timestamp", HeaderValue::Timestamp(DateTime::from_secs(5))));
assert_eq!(expected_message, actual_message);
""",
)
unitTest(
"message_with_header_and_payload",
"""
let event = TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder()
.header("header")
.payload(Blob::new(&b"payload"[..]))
.build()
);
let result = ${format(generator.render())}().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let actual_message = result.unwrap();
let expected_message = Message::new(&b"payload"[..])
.add_header(Header::new(":message-type", HeaderValue::String("event".into())))
.add_header(Header::new(":event-type", HeaderValue::String("MessageWithHeaderAndPayload".into())))
.add_header(Header::new("header", HeaderValue::String("header".into())))
.add_header(Header::new(":content-type", HeaderValue::String("application/octet-stream".into())));
assert_eq!(expected_message, actual_message);
""",
)
unitTest(
"message_with_no_header_payload_traits",
"""
let event = TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder()
.some_int(5)
.some_string("hello")
.build()
);
let result = ${format(generator.render())}().marshall(event);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
let message = result.unwrap();
let headers = headers_to_map(message.headers());
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
assert_eq!(&str_header("MessageWithNoHeaderPayloadTraits"), *headers.get(":event-type").unwrap());
assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap());
validate_body(
message.payload(),
${testCase.validMessageWithNoHeaderPayloadTraits.dq()},
MediaType::from(${testCase.requestContentType.dq()})
).unwrap();
""",
)
}
test.project.compileAndTest()
}
}