Fix constraint-related errors in Rpcv2CBOR server implementation (#3794)

This commit is contained in:
AWS SDK Rust Bot 2024-10-01 14:04:57 +01:00 committed by GitHub
commit 191c5771e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 437 additions and 43 deletions

9
.changelog/2155171.md Normal file
View File

@ -0,0 +1,9 @@
---
applies_to: ["server","client"]
authors: ["drganjoo"]
references: [smithy-rs#3573]
breaking: false
new_feature: true
bug_fix: false
---
Support for the [rpcv2Cbor](https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html) protocol has been added, allowing services to serialize RPC payloads as CBOR (Concise Binary Object Representation), improving performance and efficiency in data transmission.

View File

@ -10,6 +10,7 @@ import org.junit.jupiter.params.provider.ArgumentsSource
import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.generateRustPayloadInitializer
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.writeUnmarshallTestCases
import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams
import software.amazon.smithy.rust.codegen.core.testutil.testModule
@ -46,7 +47,7 @@ class ClientEventStreamUnmarshallerGeneratorTest {
"exception",
"UnmodeledError",
"${testCase.responseContentType}",
br#"${testCase.validUnmodeledError}"#
${testCase.generateRustPayloadInitializer(testCase.validUnmodeledError)}
);
let result = $generator::new().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);

View File

@ -25,6 +25,7 @@ dependencies {
implementation("org.jsoup:jsoup:1.16.2")
api("software.amazon.smithy:smithy-codegen-core:$smithyVersion")
api("com.moandjiezana.toml:toml4j:0.7.2")
implementation("com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:2.13.0")
implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion")
implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion")
implementation("software.amazon.smithy:smithy-waiters:$smithyVersion")

View File

@ -14,6 +14,7 @@ import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ToShapeId
import software.amazon.smithy.model.traits.HttpTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
@ -140,9 +141,23 @@ open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol {
override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType =
RuntimeType.cborErrors(runtimeConfig).resolve("parse_error_metadata")
// TODO(https://github.com/smithy-lang/smithy-rs/issues/3573)
override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType =
TODO("rpcv2Cbor event streams have not yet been implemented")
ProtocolFunctions.crossOperationFn("parse_event_stream_error_metadata") { fnName ->
rustTemplate(
"""
pub fn $fnName(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{DeserializeError}> {
#{cbor_errors}::parse_error_metadata(0, &#{Headers}::new(), payload)
}
""",
"cbor_errors" to RuntimeType.cborErrors(runtimeConfig),
"Bytes" to RuntimeType.Bytes,
"ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig),
"DeserializeError" to
CargoDependency.smithyCbor(runtimeConfig).toType()
.resolve("decode::DeserializeError"),
"Headers" to RuntimeType.headers(runtimeConfig),
)
}
// Unlike other protocols, the `rpcv2Cbor` protocol requires that `Content-Length` is always set
// unless there is no input or if the operation is an event stream, see

View File

@ -48,7 +48,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingReso
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.inputShape
@ -447,7 +446,24 @@ class CborParserGenerator(
}
override fun payloadParser(member: MemberShape): RuntimeType {
UNREACHABLE("No protocol using CBOR serialization supports payload binding")
val shape = model.expectShape(member.target)
val returnSymbol = returnSymbolToParse(shape)
check(shape is UnionShape || shape is StructureShape) {
"Payload parser should only be used on structure and union shapes."
}
return protocolFunctions.deserializeFn(shape, fnNameSuffix = "payload") { fnName ->
rustTemplate(
"""
pub(crate) fn $fnName(value: &[u8]) -> #{Result}<#{ReturnType}, #{Error}> {
let decoder = &mut #{Decoder}::new(value);
#{DeserializeMember}
}
""",
"ReturnType" to returnSymbol.symbol,
"DeserializeMember" to deserializeMember(member),
*codegenScope,
)
}
}
override fun operationParser(operationShape: OperationShape): RuntimeType? {

View File

@ -66,6 +66,10 @@ sealed class CborSerializerSection(name: String) : Section(name) {
/** Manipulate the serializer context for a map prior to it being serialized. **/
data class BeforeIteratingOverMapOrCollection(val shape: Shape, val context: CborSerializerGenerator.Context<Shape>) :
CborSerializerSection("BeforeIteratingOverMapOrCollection")
/** Manipulate the serializer context for a non-null member prior to it being serialized. **/
data class BeforeSerializingNonNullMember(val shape: Shape, val context: CborSerializerGenerator.MemberContext) :
CborSerializerSection("BeforeSerializingNonNullMember")
}
/**
@ -200,9 +204,26 @@ class CborSerializerGenerator(
}
}
// TODO(https://github.com/smithy-lang/smithy-rs/issues/3573)
override fun payloadSerializer(member: MemberShape): RuntimeType {
TODO("We only call this when serializing in event streams, which are not supported yet: https://github.com/smithy-lang/smithy-rs/issues/3573")
val target = model.expectShape(member.target)
return protocolFunctions.serializeFn(member, fnNameSuffix = "payload") { fnName ->
rustBlockTemplate(
"pub fn $fnName(input: &#{target}) -> std::result::Result<#{Vec}<u8>, #{Error}>",
*codegenScope,
"target" to symbolProvider.toSymbol(target),
) {
rustTemplate("let mut encoder = #{Encoder}::new(#{Vec}::new());", *codegenScope)
rustBlock("") {
rust("let encoder = &mut encoder;")
when (target) {
is StructureShape -> serializeStructure(StructContext("input", target))
is UnionShape -> serializeUnion(Context(ValueExpression.Reference("input"), target))
else -> throw IllegalStateException("CBOR payloadSerializer only supports structs and unions")
}
}
rustTemplate("#{Ok}(encoder.into_writer())", *codegenScope)
}
}
}
override fun unsetStructure(structure: StructureShape): RuntimeType =
@ -311,6 +332,7 @@ class CborSerializerGenerator(
safeName().also { local ->
rustBlock("if let Some($local) = ${context.valueExpression.asRef()}") {
context.valueExpression = ValueExpression.Reference(local)
resolveValueExpressionForConstrainedType(targetShape, context)
serializeMemberValue(context, targetShape)
}
if (context.writeNulls) {
@ -320,6 +342,7 @@ class CborSerializerGenerator(
}
}
} else {
resolveValueExpressionForConstrainedType(targetShape, context)
with(serializerUtil) {
ignoreDefaultsForNumbersAndBools(context.shape, context.valueExpression) {
serializeMemberValue(context, targetShape)
@ -328,6 +351,20 @@ class CborSerializerGenerator(
}
}
private fun RustWriter.resolveValueExpressionForConstrainedType(
targetShape: Shape,
context: MemberContext,
) {
for (customization in customizations) {
customization.section(
CborSerializerSection.BeforeSerializingNonNullMember(
targetShape,
context,
),
)(this)
}
}
private fun RustWriter.serializeMemberValue(
context: MemberContext,
target: Shape,
@ -362,7 +399,7 @@ class CborSerializerGenerator(
rust("$encoder;") // Encode the member key.
}
when (target) {
is StructureShape -> serializeStructure(StructContext(value.name, target))
is StructureShape -> serializeStructure(StructContext(value.asRef(), target))
is CollectionShape -> serializeCollection(Context(value, target))
is MapShape -> serializeMap(Context(value, target))
is UnionShape -> serializeUnion(Context(value, target))

View File

@ -5,6 +5,8 @@
package software.amazon.smithy.rust.codegen.core.testutil
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.dataformat.cbor.CBORFactory
import software.amazon.smithy.model.Model
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson
@ -12,16 +14,18 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2Cbor
import java.util.Base64
private fun fillInBaseModel(
protocolName: String,
namespacedProtocolName: String,
extraServiceAnnotations: String = "",
): String =
"""
namespace test
use smithy.framework#ValidationException
use aws.protocols#$protocolName
use $namespacedProtocolName
union TestUnion {
Foo: String,
@ -86,22 +90,24 @@ private fun fillInBaseModel(
}
$extraServiceAnnotations
@$protocolName
@${namespacedProtocolName.substringAfter("#")}
service TestService { version: "123", operations: [TestStreamOp] }
"""
object EventStreamTestModels {
private fun restJson1(): Model = fillInBaseModel("restJson1").asSmithyModel()
private fun restJson1(): Model = fillInBaseModel("aws.protocols#restJson1").asSmithyModel()
private fun restXml(): Model = fillInBaseModel("restXml").asSmithyModel()
private fun restXml(): Model = fillInBaseModel("aws.protocols#restXml").asSmithyModel()
private fun awsJson11(): Model = fillInBaseModel("awsJson1_1").asSmithyModel()
private fun awsJson11(): Model = fillInBaseModel("aws.protocols#awsJson1_1").asSmithyModel()
private fun rpcv2Cbor(): Model = fillInBaseModel("smithy.protocols#rpcv2Cbor").asSmithyModel()
private fun awsQuery(): Model =
fillInBaseModel("awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
fillInBaseModel("aws.protocols#awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
private fun ec2Query(): Model =
fillInBaseModel("ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
fillInBaseModel("aws.protocols#ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
data class TestCase(
val protocolShapeId: String,
@ -120,39 +126,67 @@ object EventStreamTestModels {
override fun toString(): String = protocolShapeId
}
private fun base64Encode(input: ByteArray): String {
val encodedBytes = Base64.getEncoder().encode(input)
return String(encodedBytes)
}
private fun createCborFromJson(jsonString: String): ByteArray {
val jsonMapper = ObjectMapper()
val cborMapper = ObjectMapper(CBORFactory())
// Parse JSON string to a generic type.
val jsonData = jsonMapper.readValue(jsonString, Any::class.java)
// Convert the parsed data to CBOR.
return cborMapper.writeValueAsBytes(jsonData)
}
private val restJsonTestCase =
TestCase(
protocolShapeId = "aws.protocols#restJson1",
model = restJson1(),
mediaType = "application/json",
requestContentType = "application/vnd.amazon.eventstream",
responseContentType = "application/json",
eventStreamMessageContentType = "application/json",
validTestStruct = """{"someString":"hello","someInt":5}""",
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
validTestUnion = """{"Foo":"hello"}""",
validSomeError = """{"Message":"some error"}""",
validUnmodeledError = """{"Message":"unmodeled error"}""",
) { RestJson(it) }
val TEST_CASES =
listOf(
//
// restJson1
//
TestCase(
protocolShapeId = "aws.protocols#restJson1",
model = restJson1(),
mediaType = "application/json",
requestContentType = "application/vnd.amazon.eventstream",
responseContentType = "application/json",
eventStreamMessageContentType = "application/json",
validTestStruct = """{"someString":"hello","someInt":5}""",
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
validTestUnion = """{"Foo":"hello"}""",
validSomeError = """{"Message":"some error"}""",
validUnmodeledError = """{"Message":"unmodeled error"}""",
) { RestJson(it) },
restJsonTestCase,
//
// rpcV2Cbor
//
restJsonTestCase.copy(
protocolShapeId = "smithy.protocols#rpcv2Cbor",
model = rpcv2Cbor(),
mediaType = "application/cbor",
responseContentType = "application/cbor",
eventStreamMessageContentType = "application/cbor",
validTestStruct = base64Encode(createCborFromJson(restJsonTestCase.validTestStruct)),
validMessageWithNoHeaderPayloadTraits = base64Encode(createCborFromJson(restJsonTestCase.validMessageWithNoHeaderPayloadTraits)),
validTestUnion = base64Encode(createCborFromJson(restJsonTestCase.validTestUnion)),
validSomeError = base64Encode(createCborFromJson(restJsonTestCase.validSomeError)),
validUnmodeledError = base64Encode(createCborFromJson(restJsonTestCase.validUnmodeledError)),
protocolBuilder = { RpcV2Cbor(it) },
),
//
// awsJson1_1
//
TestCase(
restJsonTestCase.copy(
protocolShapeId = "aws.protocols#awsJson1_1",
model = awsJson11(),
mediaType = "application/x-amz-json-1.1",
requestContentType = "application/x-amz-json-1.1",
responseContentType = "application/x-amz-json-1.1",
eventStreamMessageContentType = "application/json",
validTestStruct = """{"someString":"hello","someInt":5}""",
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
validTestUnion = """{"Foo":"hello"}""",
validSomeError = """{"Message":"some error"}""",
validUnmodeledError = """{"Message":"unmodeled error"}""",
) { AwsJson(it, AwsJsonVersion.Json11) },
//
// restXml

View File

@ -15,6 +15,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.util.lookup
import java.util.Base64
object EventStreamUnmarshallTestCases {
fun RustWriter.writeUnmarshallTestCases(
@ -109,7 +110,7 @@ object EventStreamUnmarshallTestCases {
"event",
"MessageWithStruct",
"${testCase.responseContentType}",
br##"${testCase.validTestStruct}"##
${testCase.generateRustPayloadInitializer(testCase.validTestStruct)}
);
let result = $generator::new().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
@ -140,7 +141,7 @@ object EventStreamUnmarshallTestCases {
"event",
"MessageWithUnion",
"${testCase.responseContentType}",
br##"${testCase.validTestUnion}"##
${testCase.generateRustPayloadInitializer(testCase.validTestUnion)}
);
let result = $generator::new().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
@ -221,7 +222,7 @@ object EventStreamUnmarshallTestCases {
"event",
"MessageWithNoHeaderPayloadTraits",
"${testCase.responseContentType}",
br##"${testCase.validMessageWithNoHeaderPayloadTraits}"##
${testCase.generateRustPayloadInitializer(testCase.validMessageWithNoHeaderPayloadTraits)}
);
let result = $generator::new().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
@ -246,7 +247,7 @@ object EventStreamUnmarshallTestCases {
"exception",
"SomeError",
"${testCase.responseContentType}",
br##"${testCase.validSomeError}"##
${testCase.generateRustPayloadInitializer(testCase.validSomeError)}
);
let result = $generator::new().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
@ -267,7 +268,7 @@ object EventStreamUnmarshallTestCases {
"event",
"MessageWithBlob",
"wrong-content-type",
br#"${testCase.validTestStruct}"#
${testCase.generateRustPayloadInitializer(testCase.validTestStruct)}
);
let result = $generator::new().unmarshall(&message);
assert!(result.is_err(), "expected error, got: {:?}", result);
@ -275,6 +276,35 @@ object EventStreamUnmarshallTestCases {
""",
)
}
/**
* Generates a Rust-compatible initializer string for a given payload.
*
* This function handles two different scenarios based on the event stream message content type:
*
* 1. For CBOR payloads (content type "application/cbor"):
* - The input payload is expected to be a base64 encoded CBOR value.
* - It decodes the base64 string and generates a Rust byte array initializer.
* - The output format is: &[0xFFu8, 0xFFu8, ...] where FF are hexadecimal values.
*
* 2. For all other content types:
* - It returns a Rust raw string literal initializer.
* - The output format is: br##"original_payload"##
*/
fun EventStreamTestModels.TestCase.generateRustPayloadInitializer(payload: String): String {
return if (this.eventStreamMessageContentType == "application/cbor") {
Base64.getDecoder().decode(payload)
.joinToString(
prefix = "&[",
postfix = "]",
transform = { "0x${it.toUByte().toString(16).padStart(2, '0')}u8" },
)
} else {
"""
br##"$payload"##
"""
}
}
}
internal fun conditionalBuilderInput(

View File

@ -90,6 +90,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.transformers.AttachVali
import software.amazon.smithy.rust.codegen.server.smithy.transformers.ConstrainedMemberTransform
import software.amazon.smithy.rust.codegen.server.smithy.transformers.RecursiveConstraintViolationBoxer
import software.amazon.smithy.rust.codegen.server.smithy.transformers.RemoveEbsModelValidationException
import software.amazon.smithy.rust.codegen.server.smithy.transformers.ServerProtocolBasedTransformationFactory
import software.amazon.smithy.rust.codegen.server.smithy.transformers.ShapesReachableFromOperationInputTagger
import java.util.logging.Logger
@ -208,6 +209,8 @@ open class ServerCodegenVisitor(
.let { AttachValidationExceptionToConstrainedOperationInputs.transform(it, settings) }
// Tag aggregate shapes reachable from operation input
.let(ShapesReachableFromOperationInputTagger::transform)
// Remove traits that are not supported by the chosen protocol.
.let { ServerProtocolBasedTransformationFactory.transform(it, settings) }
// Normalize event stream operations
.let(EventStreamNormalizer::transform)

View File

@ -0,0 +1,47 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.server.smithy.customizations
import software.amazon.smithy.model.shapes.BlobShape
import software.amazon.smithy.model.shapes.ByteShape
import software.amazon.smithy.model.shapes.IntegerShape
import software.amazon.smithy.model.shapes.LongShape
import software.amazon.smithy.model.shapes.ShortShape
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerCustomization
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerSection
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.workingWithPublicConstrainedWrapperTupleType
/**
* Constrained shapes are wrapped in a Rust tuple struct that implements all necessary checks. However,
* for serialization purposes, the inner type of the constrained shape is used for serialization.
*
* The `BeforeSerializingMemberCborCustomization` class generates a reference to the inner type when the shape being
* code-generated is constrained and the `publicConstrainedTypes` codegen flag is set.
*/
class BeforeSerializingMemberCborCustomization(private val codegenContext: ServerCodegenContext) : CborSerializerCustomization() {
override fun section(section: CborSerializerSection): Writable =
when (section) {
is CborSerializerSection.BeforeSerializingNonNullMember ->
writable {
if (workingWithPublicConstrainedWrapperTupleType(
section.shape,
codegenContext.model,
codegenContext.settings.codegenConfig.publicConstrainedTypes,
)
) {
if (section.shape is IntegerShape || section.shape is ShortShape || section.shape is LongShape || section.shape is ByteShape || section.shape is BlobShape) {
section.context.valueExpression =
ValueExpression.Reference("&${section.context.valueExpression.name}.0")
}
}
}
else -> emptySection
}
}

View File

@ -47,6 +47,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape
import software.amazon.smithy.rust.codegen.server.smithy.customizations.AddTypeFieldToServerErrorsCborCustomization
import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeEncodingMapOrCollectionCborCustomization
import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeSerializingMemberCborCustomization
import software.amazon.smithy.rust.codegen.server.smithy.generators.http.RestRequestSpecGenerator
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerAwsJsonSerializerGenerator
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRestJsonSerializerGenerator
@ -349,6 +350,7 @@ class ServerRpcV2CborProtocol(
listOf(
BeforeEncodingMapOrCollectionCborCustomization(serverCodegenContext),
AddTypeFieldToServerErrorsCborCustomization(),
BeforeSerializingMemberCborCustomization(serverCodegenContext),
),
)
}

View File

@ -0,0 +1,77 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.server.smithy.transformers
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.AbstractShapeBuilder
import software.amazon.smithy.model.shapes.BlobShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.traits.HttpLabelTrait
import software.amazon.smithy.model.traits.HttpPayloadTrait
import software.amazon.smithy.model.traits.HttpTrait
import software.amazon.smithy.model.traits.StreamingTrait
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.protocol.traits.Rpcv2CborTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.server.smithy.ServerRustSettings
import software.amazon.smithy.utils.SmithyBuilder
import software.amazon.smithy.utils.ToSmithyBuilder
/**
* Each protocol may not support all of the features that Smithy allows. For instance, `rpcv2Cbor`
* does not support HTTP bindings other than `@httpError`. `ServerProtocolBasedTransformationFactory` is a factory
* object that transforms the model and removes specific traits based on the protocol being instantiated.
*
* In the long term, this class will be removed, and each protocol should be resilient enough to ignore extra
* traits that the model is annotated with. This will be addressed when we fix issue
* [#2979](https://github.com/smithy-lang/smithy-rs/issues/2979).
*/
object ServerProtocolBasedTransformationFactory {
fun transform(
model: Model,
settings: ServerRustSettings,
): Model {
val service = settings.getService(model)
if (!service.hasTrait<Rpcv2CborTrait>()) {
return model
}
// `rpcv2Cbor` does not support:
// 1. `@httpPayload` trait.
// 2. `@httpLabel` trait.
// 3. `@streaming` trait applied to a `Blob` (data streaming).
return ModelTransformer.create().mapShapes(model) { shape ->
when (shape) {
is OperationShape -> shape.removeTraitIfPresent(HttpTrait.ID)
is MemberShape -> {
shape
.removeTraitIfPresent(HttpLabelTrait.ID)
.removeTraitIfPresent(HttpPayloadTrait.ID)
}
is BlobShape -> {
shape.removeTraitIfPresent(StreamingTrait.ID)
}
else -> shape
}
}
}
fun <T : Shape, B> T.removeTraitIfPresent(
traitId: ShapeId,
): T
where T : ToSmithyBuilder<T>,
B : AbstractShapeBuilder<B, T>,
B : SmithyBuilder<T> {
return if (this.hasTrait(traitId)) {
@Suppress("UNCHECKED_CAST")
(this.toBuilder() as B).removeTrait(traitId).build()
} else {
this
}
}
}

View File

@ -6,17 +6,103 @@
package software.amazon.smithy.rust.codegen.server.smithy
import io.kotest.inspectors.forAll
import io.kotest.matchers.ints.shouldBeGreaterThan
import io.kotest.matchers.shouldBe
import org.junit.jupiter.api.Test
import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.BooleanShape
import software.amazon.smithy.model.shapes.ListShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.AbstractTrait
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.protocol.traits.Rpcv2CborTrait
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.util.lookup
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider
import java.io.File
enum class ModelProtocol(val trait: AbstractTrait) {
AwsJson10(AwsJson1_0Trait.builder().build()),
AwsJson11(AwsJson1_1Trait.builder().build()),
RestJson(RestJson1Trait.builder().build()),
RestXml(RestXmlTrait.builder().build()),
Rpcv2Cbor(Rpcv2CborTrait.builder().build()),
}
/**
* Returns the Smithy constraints model from the common repository, with the specified protocol
* applied to the service.
*/
fun loadSmithyConstraintsModelForProtocol(modelProtocol: ModelProtocol): Pair<Model, ShapeId> {
val (model, serviceShapeId) = loadSmithyConstraintsModel()
return Pair(model.replaceProtocolTrait(serviceShapeId, modelProtocol), serviceShapeId)
}
/**
* Loads the Smithy constraints model defined in the common repository and returns the model along with
* the service shape defined in it.
*/
fun loadSmithyConstraintsModel(): Pair<Model, ShapeId> {
val filePath = "../codegen-core/common-test-models/constraints.smithy"
val model =
File(filePath).readText().asSmithyModel()
val serviceShapeId = model.shapes().filter { it.isServiceShape }.findFirst().orElseThrow().id
return Pair(model, serviceShapeId)
}
/**
* Removes all existing protocol traits annotated on the given service,
* then sets the provided `protocol` as the sole protocol trait for the service.
*/
fun Model.replaceProtocolTrait(
serviceShapeId: ShapeId,
modelProtocol: ModelProtocol,
): Model {
val serviceBuilder =
this.expectShape(serviceShapeId, ServiceShape::class.java).toBuilder()
for (p in ModelProtocol.values()) {
serviceBuilder.removeTrait(p.trait.toShapeId())
}
val service = serviceBuilder.addTrait(modelProtocol.trait).build()
return ModelTransformer.create().replaceShapes(this, listOf(service))
}
fun List<ShapeId>.containsAnyShapeId(ids: Collection<ShapeId>): Boolean {
return ids.any { id -> this.any { shape -> shape == id } }
}
/**
* Removes the given operations from the model.
*/
fun Model.removeOperations(
serviceShapeId: ShapeId,
operationsToRemove: List<ShapeId>,
): Model {
val service = this.expectShape(serviceShapeId, ServiceShape::class.java)
val serviceBuilder = service.toBuilder()
// The operation must exist in the service.
service.operations.map { it.toShapeId() }.containsAll(operationsToRemove) shouldBe true
// Remove all operations.
for (opToRemove in operationsToRemove) {
serviceBuilder.removeOperation(opToRemove)
}
val changedModel = ModelTransformer.create().replaceShapes(this, listOf(serviceBuilder.build()))
// The operation must not exist in the updated service.
val changedService = changedModel.expectShape(serviceShapeId, ServiceShape::class.java)
changedService.operations.size shouldBeGreaterThan 0
changedService.operations.map { it.toShapeId() }.containsAnyShapeId(operationsToRemove) shouldBe false
return changedModel
}
class ConstraintsTest {
private val model =

View File

@ -0,0 +1,28 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.server.smithy.protocols.serialize
import org.junit.jupiter.api.Test
import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams
import software.amazon.smithy.rust.codegen.core.testutil.ServerAdditionalSettings
import software.amazon.smithy.rust.codegen.server.smithy.ModelProtocol
import software.amazon.smithy.rust.codegen.server.smithy.loadSmithyConstraintsModelForProtocol
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest
class CborConstraintsIntegrationTest {
@Test
fun `ensure CBOR implementation works for all constraint types`() {
val (model, serviceShape) = loadSmithyConstraintsModelForProtocol(ModelProtocol.Rpcv2Cbor)
// The test should compile; no further testing is required.
serverIntegrationTest(
model,
IntegrationTestParams(
service = serviceShape.toString(),
additionalSettings = ServerAdditionalSettings.builder().generateCodegenComments().toObjectNode(),
),
) { _, _ ->
}
}
}

View File

@ -32,7 +32,7 @@ pub fn parse_error_metadata(
_response_headers: &Headers,
response_body: &[u8],
) -> Result<ErrorMetadataBuilder, DeserializeError> {
fn error_code(
fn error_code_and_message(
mut builder: ErrorMetadataBuilder,
decoder: &mut Decoder,
) -> Result<ErrorMetadataBuilder, DeserializeError> {
@ -41,6 +41,14 @@ pub fn parse_error_metadata(
let code = decoder.str()?;
builder.code(sanitize_error_code(&code))
}
"message" | "Message" | "errorMessage" => {
// Silently skip if `message` is not a string. This allows for custom error
// structures that might use different types for the message field.
match decoder.str() {
Ok(message) => builder.message(message),
Err(_) => builder,
}
}
_ => {
decoder.skip()?;
builder
@ -60,13 +68,13 @@ pub fn parse_error_metadata(
break;
}
_ => {
builder = error_code(builder, decoder)?;
builder = error_code_and_message(builder, decoder)?;
}
};
},
Some(n) => {
for _ in 0..n {
builder = error_code(builder, decoder)?;
builder = error_code_and_message(builder, decoder)?;
}
}
};