mirror of https://github.com/smithy-lang/smithy-rs
Fix constraint-related errors in Rpcv2CBOR server implementation (#3794)
This commit is contained in:
commit
191c5771e3
|
@ -0,0 +1,9 @@
|
|||
---
|
||||
applies_to: ["server","client"]
|
||||
authors: ["drganjoo"]
|
||||
references: [smithy-rs#3573]
|
||||
breaking: false
|
||||
new_feature: true
|
||||
bug_fix: false
|
||||
---
|
||||
Support for the [rpcv2Cbor](https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html) protocol has been added, allowing services to serialize RPC payloads as CBOR (Concise Binary Object Representation), improving performance and efficiency in data transmission.
|
|
@ -10,6 +10,7 @@ import org.junit.jupiter.params.provider.ArgumentsSource
|
|||
import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rust
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.generateRustPayloadInitializer
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.writeUnmarshallTestCases
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.testModule
|
||||
|
@ -46,7 +47,7 @@ class ClientEventStreamUnmarshallerGeneratorTest {
|
|||
"exception",
|
||||
"UnmodeledError",
|
||||
"${testCase.responseContentType}",
|
||||
br#"${testCase.validUnmodeledError}"#
|
||||
${testCase.generateRustPayloadInitializer(testCase.validUnmodeledError)}
|
||||
);
|
||||
let result = $generator::new().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
|
|
|
@ -25,6 +25,7 @@ dependencies {
|
|||
implementation("org.jsoup:jsoup:1.16.2")
|
||||
api("software.amazon.smithy:smithy-codegen-core:$smithyVersion")
|
||||
api("com.moandjiezana.toml:toml4j:0.7.2")
|
||||
implementation("com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:2.13.0")
|
||||
implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion")
|
||||
implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion")
|
||||
implementation("software.amazon.smithy:smithy-waiters:$smithyVersion")
|
||||
|
|
|
@ -14,6 +14,7 @@ import software.amazon.smithy.model.shapes.ServiceShape
|
|||
import software.amazon.smithy.model.shapes.ToShapeId
|
||||
import software.amazon.smithy.model.traits.HttpTrait
|
||||
import software.amazon.smithy.model.traits.TimestampFormatTrait
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.writable
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
|
||||
|
@ -140,9 +141,23 @@ open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol {
|
|||
override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType =
|
||||
RuntimeType.cborErrors(runtimeConfig).resolve("parse_error_metadata")
|
||||
|
||||
// TODO(https://github.com/smithy-lang/smithy-rs/issues/3573)
|
||||
override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType =
|
||||
TODO("rpcv2Cbor event streams have not yet been implemented")
|
||||
ProtocolFunctions.crossOperationFn("parse_event_stream_error_metadata") { fnName ->
|
||||
rustTemplate(
|
||||
"""
|
||||
pub fn $fnName(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{DeserializeError}> {
|
||||
#{cbor_errors}::parse_error_metadata(0, &#{Headers}::new(), payload)
|
||||
}
|
||||
""",
|
||||
"cbor_errors" to RuntimeType.cborErrors(runtimeConfig),
|
||||
"Bytes" to RuntimeType.Bytes,
|
||||
"ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig),
|
||||
"DeserializeError" to
|
||||
CargoDependency.smithyCbor(runtimeConfig).toType()
|
||||
.resolve("decode::DeserializeError"),
|
||||
"Headers" to RuntimeType.headers(runtimeConfig),
|
||||
)
|
||||
}
|
||||
|
||||
// Unlike other protocols, the `rpcv2Cbor` protocol requires that `Content-Length` is always set
|
||||
// unless there is no input or if the operation is an event stream, see
|
||||
|
|
|
@ -48,7 +48,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingReso
|
|||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions
|
||||
import software.amazon.smithy.rust.codegen.core.util.PANIC
|
||||
import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE
|
||||
import software.amazon.smithy.rust.codegen.core.util.dq
|
||||
import software.amazon.smithy.rust.codegen.core.util.hasTrait
|
||||
import software.amazon.smithy.rust.codegen.core.util.inputShape
|
||||
|
@ -447,7 +446,24 @@ class CborParserGenerator(
|
|||
}
|
||||
|
||||
override fun payloadParser(member: MemberShape): RuntimeType {
|
||||
UNREACHABLE("No protocol using CBOR serialization supports payload binding")
|
||||
val shape = model.expectShape(member.target)
|
||||
val returnSymbol = returnSymbolToParse(shape)
|
||||
check(shape is UnionShape || shape is StructureShape) {
|
||||
"Payload parser should only be used on structure and union shapes."
|
||||
}
|
||||
return protocolFunctions.deserializeFn(shape, fnNameSuffix = "payload") { fnName ->
|
||||
rustTemplate(
|
||||
"""
|
||||
pub(crate) fn $fnName(value: &[u8]) -> #{Result}<#{ReturnType}, #{Error}> {
|
||||
let decoder = &mut #{Decoder}::new(value);
|
||||
#{DeserializeMember}
|
||||
}
|
||||
""",
|
||||
"ReturnType" to returnSymbol.symbol,
|
||||
"DeserializeMember" to deserializeMember(member),
|
||||
*codegenScope,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
override fun operationParser(operationShape: OperationShape): RuntimeType? {
|
||||
|
|
|
@ -66,6 +66,10 @@ sealed class CborSerializerSection(name: String) : Section(name) {
|
|||
/** Manipulate the serializer context for a map prior to it being serialized. **/
|
||||
data class BeforeIteratingOverMapOrCollection(val shape: Shape, val context: CborSerializerGenerator.Context<Shape>) :
|
||||
CborSerializerSection("BeforeIteratingOverMapOrCollection")
|
||||
|
||||
/** Manipulate the serializer context for a non-null member prior to it being serialized. **/
|
||||
data class BeforeSerializingNonNullMember(val shape: Shape, val context: CborSerializerGenerator.MemberContext) :
|
||||
CborSerializerSection("BeforeSerializingNonNullMember")
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -200,9 +204,26 @@ class CborSerializerGenerator(
|
|||
}
|
||||
}
|
||||
|
||||
// TODO(https://github.com/smithy-lang/smithy-rs/issues/3573)
|
||||
override fun payloadSerializer(member: MemberShape): RuntimeType {
|
||||
TODO("We only call this when serializing in event streams, which are not supported yet: https://github.com/smithy-lang/smithy-rs/issues/3573")
|
||||
val target = model.expectShape(member.target)
|
||||
return protocolFunctions.serializeFn(member, fnNameSuffix = "payload") { fnName ->
|
||||
rustBlockTemplate(
|
||||
"pub fn $fnName(input: &#{target}) -> std::result::Result<#{Vec}<u8>, #{Error}>",
|
||||
*codegenScope,
|
||||
"target" to symbolProvider.toSymbol(target),
|
||||
) {
|
||||
rustTemplate("let mut encoder = #{Encoder}::new(#{Vec}::new());", *codegenScope)
|
||||
rustBlock("") {
|
||||
rust("let encoder = &mut encoder;")
|
||||
when (target) {
|
||||
is StructureShape -> serializeStructure(StructContext("input", target))
|
||||
is UnionShape -> serializeUnion(Context(ValueExpression.Reference("input"), target))
|
||||
else -> throw IllegalStateException("CBOR payloadSerializer only supports structs and unions")
|
||||
}
|
||||
}
|
||||
rustTemplate("#{Ok}(encoder.into_writer())", *codegenScope)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun unsetStructure(structure: StructureShape): RuntimeType =
|
||||
|
@ -311,6 +332,7 @@ class CborSerializerGenerator(
|
|||
safeName().also { local ->
|
||||
rustBlock("if let Some($local) = ${context.valueExpression.asRef()}") {
|
||||
context.valueExpression = ValueExpression.Reference(local)
|
||||
resolveValueExpressionForConstrainedType(targetShape, context)
|
||||
serializeMemberValue(context, targetShape)
|
||||
}
|
||||
if (context.writeNulls) {
|
||||
|
@ -320,6 +342,7 @@ class CborSerializerGenerator(
|
|||
}
|
||||
}
|
||||
} else {
|
||||
resolveValueExpressionForConstrainedType(targetShape, context)
|
||||
with(serializerUtil) {
|
||||
ignoreDefaultsForNumbersAndBools(context.shape, context.valueExpression) {
|
||||
serializeMemberValue(context, targetShape)
|
||||
|
@ -328,6 +351,20 @@ class CborSerializerGenerator(
|
|||
}
|
||||
}
|
||||
|
||||
private fun RustWriter.resolveValueExpressionForConstrainedType(
|
||||
targetShape: Shape,
|
||||
context: MemberContext,
|
||||
) {
|
||||
for (customization in customizations) {
|
||||
customization.section(
|
||||
CborSerializerSection.BeforeSerializingNonNullMember(
|
||||
targetShape,
|
||||
context,
|
||||
),
|
||||
)(this)
|
||||
}
|
||||
}
|
||||
|
||||
private fun RustWriter.serializeMemberValue(
|
||||
context: MemberContext,
|
||||
target: Shape,
|
||||
|
@ -362,7 +399,7 @@ class CborSerializerGenerator(
|
|||
rust("$encoder;") // Encode the member key.
|
||||
}
|
||||
when (target) {
|
||||
is StructureShape -> serializeStructure(StructContext(value.name, target))
|
||||
is StructureShape -> serializeStructure(StructContext(value.asRef(), target))
|
||||
is CollectionShape -> serializeCollection(Context(value, target))
|
||||
is MapShape -> serializeMap(Context(value, target))
|
||||
is UnionShape -> serializeUnion(Context(value, target))
|
||||
|
|
|
@ -5,6 +5,8 @@
|
|||
|
||||
package software.amazon.smithy.rust.codegen.core.testutil
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper
|
||||
import com.fasterxml.jackson.dataformat.cbor.CBORFactory
|
||||
import software.amazon.smithy.model.Model
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson
|
||||
|
@ -12,16 +14,18 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion
|
|||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2Cbor
|
||||
import java.util.Base64
|
||||
|
||||
private fun fillInBaseModel(
|
||||
protocolName: String,
|
||||
namespacedProtocolName: String,
|
||||
extraServiceAnnotations: String = "",
|
||||
): String =
|
||||
"""
|
||||
namespace test
|
||||
|
||||
use smithy.framework#ValidationException
|
||||
use aws.protocols#$protocolName
|
||||
use $namespacedProtocolName
|
||||
|
||||
union TestUnion {
|
||||
Foo: String,
|
||||
|
@ -86,22 +90,24 @@ private fun fillInBaseModel(
|
|||
}
|
||||
|
||||
$extraServiceAnnotations
|
||||
@$protocolName
|
||||
@${namespacedProtocolName.substringAfter("#")}
|
||||
service TestService { version: "123", operations: [TestStreamOp] }
|
||||
"""
|
||||
|
||||
object EventStreamTestModels {
|
||||
private fun restJson1(): Model = fillInBaseModel("restJson1").asSmithyModel()
|
||||
private fun restJson1(): Model = fillInBaseModel("aws.protocols#restJson1").asSmithyModel()
|
||||
|
||||
private fun restXml(): Model = fillInBaseModel("restXml").asSmithyModel()
|
||||
private fun restXml(): Model = fillInBaseModel("aws.protocols#restXml").asSmithyModel()
|
||||
|
||||
private fun awsJson11(): Model = fillInBaseModel("awsJson1_1").asSmithyModel()
|
||||
private fun awsJson11(): Model = fillInBaseModel("aws.protocols#awsJson1_1").asSmithyModel()
|
||||
|
||||
private fun rpcv2Cbor(): Model = fillInBaseModel("smithy.protocols#rpcv2Cbor").asSmithyModel()
|
||||
|
||||
private fun awsQuery(): Model =
|
||||
fillInBaseModel("awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
|
||||
fillInBaseModel("aws.protocols#awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
|
||||
|
||||
private fun ec2Query(): Model =
|
||||
fillInBaseModel("ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
|
||||
fillInBaseModel("aws.protocols#ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
|
||||
|
||||
data class TestCase(
|
||||
val protocolShapeId: String,
|
||||
|
@ -120,39 +126,67 @@ object EventStreamTestModels {
|
|||
override fun toString(): String = protocolShapeId
|
||||
}
|
||||
|
||||
private fun base64Encode(input: ByteArray): String {
|
||||
val encodedBytes = Base64.getEncoder().encode(input)
|
||||
return String(encodedBytes)
|
||||
}
|
||||
|
||||
private fun createCborFromJson(jsonString: String): ByteArray {
|
||||
val jsonMapper = ObjectMapper()
|
||||
val cborMapper = ObjectMapper(CBORFactory())
|
||||
// Parse JSON string to a generic type.
|
||||
val jsonData = jsonMapper.readValue(jsonString, Any::class.java)
|
||||
// Convert the parsed data to CBOR.
|
||||
return cborMapper.writeValueAsBytes(jsonData)
|
||||
}
|
||||
|
||||
private val restJsonTestCase =
|
||||
TestCase(
|
||||
protocolShapeId = "aws.protocols#restJson1",
|
||||
model = restJson1(),
|
||||
mediaType = "application/json",
|
||||
requestContentType = "application/vnd.amazon.eventstream",
|
||||
responseContentType = "application/json",
|
||||
eventStreamMessageContentType = "application/json",
|
||||
validTestStruct = """{"someString":"hello","someInt":5}""",
|
||||
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
|
||||
validTestUnion = """{"Foo":"hello"}""",
|
||||
validSomeError = """{"Message":"some error"}""",
|
||||
validUnmodeledError = """{"Message":"unmodeled error"}""",
|
||||
) { RestJson(it) }
|
||||
|
||||
val TEST_CASES =
|
||||
listOf(
|
||||
//
|
||||
// restJson1
|
||||
//
|
||||
TestCase(
|
||||
protocolShapeId = "aws.protocols#restJson1",
|
||||
model = restJson1(),
|
||||
mediaType = "application/json",
|
||||
requestContentType = "application/vnd.amazon.eventstream",
|
||||
responseContentType = "application/json",
|
||||
eventStreamMessageContentType = "application/json",
|
||||
validTestStruct = """{"someString":"hello","someInt":5}""",
|
||||
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
|
||||
validTestUnion = """{"Foo":"hello"}""",
|
||||
validSomeError = """{"Message":"some error"}""",
|
||||
validUnmodeledError = """{"Message":"unmodeled error"}""",
|
||||
) { RestJson(it) },
|
||||
restJsonTestCase,
|
||||
//
|
||||
// rpcV2Cbor
|
||||
//
|
||||
restJsonTestCase.copy(
|
||||
protocolShapeId = "smithy.protocols#rpcv2Cbor",
|
||||
model = rpcv2Cbor(),
|
||||
mediaType = "application/cbor",
|
||||
responseContentType = "application/cbor",
|
||||
eventStreamMessageContentType = "application/cbor",
|
||||
validTestStruct = base64Encode(createCborFromJson(restJsonTestCase.validTestStruct)),
|
||||
validMessageWithNoHeaderPayloadTraits = base64Encode(createCborFromJson(restJsonTestCase.validMessageWithNoHeaderPayloadTraits)),
|
||||
validTestUnion = base64Encode(createCborFromJson(restJsonTestCase.validTestUnion)),
|
||||
validSomeError = base64Encode(createCborFromJson(restJsonTestCase.validSomeError)),
|
||||
validUnmodeledError = base64Encode(createCborFromJson(restJsonTestCase.validUnmodeledError)),
|
||||
protocolBuilder = { RpcV2Cbor(it) },
|
||||
),
|
||||
//
|
||||
// awsJson1_1
|
||||
//
|
||||
TestCase(
|
||||
restJsonTestCase.copy(
|
||||
protocolShapeId = "aws.protocols#awsJson1_1",
|
||||
model = awsJson11(),
|
||||
mediaType = "application/x-amz-json-1.1",
|
||||
requestContentType = "application/x-amz-json-1.1",
|
||||
responseContentType = "application/x-amz-json-1.1",
|
||||
eventStreamMessageContentType = "application/json",
|
||||
validTestStruct = """{"someString":"hello","someInt":5}""",
|
||||
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
|
||||
validTestUnion = """{"Foo":"hello"}""",
|
||||
validSomeError = """{"Message":"some error"}""",
|
||||
validUnmodeledError = """{"Message":"unmodeled error"}""",
|
||||
) { AwsJson(it, AwsJsonVersion.Json11) },
|
||||
//
|
||||
// restXml
|
||||
|
|
|
@ -15,6 +15,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
|
|||
import software.amazon.smithy.rust.codegen.core.rustlang.writable
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
|
||||
import software.amazon.smithy.rust.codegen.core.util.lookup
|
||||
import java.util.Base64
|
||||
|
||||
object EventStreamUnmarshallTestCases {
|
||||
fun RustWriter.writeUnmarshallTestCases(
|
||||
|
@ -109,7 +110,7 @@ object EventStreamUnmarshallTestCases {
|
|||
"event",
|
||||
"MessageWithStruct",
|
||||
"${testCase.responseContentType}",
|
||||
br##"${testCase.validTestStruct}"##
|
||||
${testCase.generateRustPayloadInitializer(testCase.validTestStruct)}
|
||||
);
|
||||
let result = $generator::new().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
|
@ -140,7 +141,7 @@ object EventStreamUnmarshallTestCases {
|
|||
"event",
|
||||
"MessageWithUnion",
|
||||
"${testCase.responseContentType}",
|
||||
br##"${testCase.validTestUnion}"##
|
||||
${testCase.generateRustPayloadInitializer(testCase.validTestUnion)}
|
||||
);
|
||||
let result = $generator::new().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
|
@ -221,7 +222,7 @@ object EventStreamUnmarshallTestCases {
|
|||
"event",
|
||||
"MessageWithNoHeaderPayloadTraits",
|
||||
"${testCase.responseContentType}",
|
||||
br##"${testCase.validMessageWithNoHeaderPayloadTraits}"##
|
||||
${testCase.generateRustPayloadInitializer(testCase.validMessageWithNoHeaderPayloadTraits)}
|
||||
);
|
||||
let result = $generator::new().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
|
@ -246,7 +247,7 @@ object EventStreamUnmarshallTestCases {
|
|||
"exception",
|
||||
"SomeError",
|
||||
"${testCase.responseContentType}",
|
||||
br##"${testCase.validSomeError}"##
|
||||
${testCase.generateRustPayloadInitializer(testCase.validSomeError)}
|
||||
);
|
||||
let result = $generator::new().unmarshall(&message);
|
||||
assert!(result.is_ok(), "expected ok, got: {:?}", result);
|
||||
|
@ -267,7 +268,7 @@ object EventStreamUnmarshallTestCases {
|
|||
"event",
|
||||
"MessageWithBlob",
|
||||
"wrong-content-type",
|
||||
br#"${testCase.validTestStruct}"#
|
||||
${testCase.generateRustPayloadInitializer(testCase.validTestStruct)}
|
||||
);
|
||||
let result = $generator::new().unmarshall(&message);
|
||||
assert!(result.is_err(), "expected error, got: {:?}", result);
|
||||
|
@ -275,6 +276,35 @@ object EventStreamUnmarshallTestCases {
|
|||
""",
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a Rust-compatible initializer string for a given payload.
|
||||
*
|
||||
* This function handles two different scenarios based on the event stream message content type:
|
||||
*
|
||||
* 1. For CBOR payloads (content type "application/cbor"):
|
||||
* - The input payload is expected to be a base64 encoded CBOR value.
|
||||
* - It decodes the base64 string and generates a Rust byte array initializer.
|
||||
* - The output format is: &[0xFFu8, 0xFFu8, ...] where FF are hexadecimal values.
|
||||
*
|
||||
* 2. For all other content types:
|
||||
* - It returns a Rust raw string literal initializer.
|
||||
* - The output format is: br##"original_payload"##
|
||||
*/
|
||||
fun EventStreamTestModels.TestCase.generateRustPayloadInitializer(payload: String): String {
|
||||
return if (this.eventStreamMessageContentType == "application/cbor") {
|
||||
Base64.getDecoder().decode(payload)
|
||||
.joinToString(
|
||||
prefix = "&[",
|
||||
postfix = "]",
|
||||
transform = { "0x${it.toUByte().toString(16).padStart(2, '0')}u8" },
|
||||
)
|
||||
} else {
|
||||
"""
|
||||
br##"$payload"##
|
||||
"""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal fun conditionalBuilderInput(
|
||||
|
|
|
@ -90,6 +90,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.transformers.AttachVali
|
|||
import software.amazon.smithy.rust.codegen.server.smithy.transformers.ConstrainedMemberTransform
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.transformers.RecursiveConstraintViolationBoxer
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.transformers.RemoveEbsModelValidationException
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.transformers.ServerProtocolBasedTransformationFactory
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.transformers.ShapesReachableFromOperationInputTagger
|
||||
import java.util.logging.Logger
|
||||
|
||||
|
@ -208,6 +209,8 @@ open class ServerCodegenVisitor(
|
|||
.let { AttachValidationExceptionToConstrainedOperationInputs.transform(it, settings) }
|
||||
// Tag aggregate shapes reachable from operation input
|
||||
.let(ShapesReachableFromOperationInputTagger::transform)
|
||||
// Remove traits that are not supported by the chosen protocol.
|
||||
.let { ServerProtocolBasedTransformationFactory.transform(it, settings) }
|
||||
// Normalize event stream operations
|
||||
.let(EventStreamNormalizer::transform)
|
||||
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
package software.amazon.smithy.rust.codegen.server.smithy.customizations
|
||||
|
||||
import software.amazon.smithy.model.shapes.BlobShape
|
||||
import software.amazon.smithy.model.shapes.ByteShape
|
||||
import software.amazon.smithy.model.shapes.IntegerShape
|
||||
import software.amazon.smithy.model.shapes.LongShape
|
||||
import software.amazon.smithy.model.shapes.ShortShape
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.writable
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerCustomization
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerSection
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.workingWithPublicConstrainedWrapperTupleType
|
||||
|
||||
/**
|
||||
* Constrained shapes are wrapped in a Rust tuple struct that implements all necessary checks. However,
|
||||
* for serialization purposes, the inner type of the constrained shape is used for serialization.
|
||||
*
|
||||
* The `BeforeSerializingMemberCborCustomization` class generates a reference to the inner type when the shape being
|
||||
* code-generated is constrained and the `publicConstrainedTypes` codegen flag is set.
|
||||
*/
|
||||
class BeforeSerializingMemberCborCustomization(private val codegenContext: ServerCodegenContext) : CborSerializerCustomization() {
|
||||
override fun section(section: CborSerializerSection): Writable =
|
||||
when (section) {
|
||||
is CborSerializerSection.BeforeSerializingNonNullMember ->
|
||||
writable {
|
||||
if (workingWithPublicConstrainedWrapperTupleType(
|
||||
section.shape,
|
||||
codegenContext.model,
|
||||
codegenContext.settings.codegenConfig.publicConstrainedTypes,
|
||||
)
|
||||
) {
|
||||
if (section.shape is IntegerShape || section.shape is ShortShape || section.shape is LongShape || section.shape is ByteShape || section.shape is BlobShape) {
|
||||
section.context.valueExpression =
|
||||
ValueExpression.Reference("&${section.context.valueExpression.name}.0")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
else -> emptySection
|
||||
}
|
||||
}
|
|
@ -47,6 +47,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
|
|||
import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.customizations.AddTypeFieldToServerErrorsCborCustomization
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeEncodingMapOrCollectionCborCustomization
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeSerializingMemberCborCustomization
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.generators.http.RestRequestSpecGenerator
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerAwsJsonSerializerGenerator
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRestJsonSerializerGenerator
|
||||
|
@ -349,6 +350,7 @@ class ServerRpcV2CborProtocol(
|
|||
listOf(
|
||||
BeforeEncodingMapOrCollectionCborCustomization(serverCodegenContext),
|
||||
AddTypeFieldToServerErrorsCborCustomization(),
|
||||
BeforeSerializingMemberCborCustomization(serverCodegenContext),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
package software.amazon.smithy.rust.codegen.server.smithy.transformers
|
||||
|
||||
import software.amazon.smithy.model.Model
|
||||
import software.amazon.smithy.model.shapes.AbstractShapeBuilder
|
||||
import software.amazon.smithy.model.shapes.BlobShape
|
||||
import software.amazon.smithy.model.shapes.MemberShape
|
||||
import software.amazon.smithy.model.shapes.OperationShape
|
||||
import software.amazon.smithy.model.shapes.Shape
|
||||
import software.amazon.smithy.model.shapes.ShapeId
|
||||
import software.amazon.smithy.model.traits.HttpLabelTrait
|
||||
import software.amazon.smithy.model.traits.HttpPayloadTrait
|
||||
import software.amazon.smithy.model.traits.HttpTrait
|
||||
import software.amazon.smithy.model.traits.StreamingTrait
|
||||
import software.amazon.smithy.model.transform.ModelTransformer
|
||||
import software.amazon.smithy.protocol.traits.Rpcv2CborTrait
|
||||
import software.amazon.smithy.rust.codegen.core.util.hasTrait
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.ServerRustSettings
|
||||
import software.amazon.smithy.utils.SmithyBuilder
|
||||
import software.amazon.smithy.utils.ToSmithyBuilder
|
||||
|
||||
/**
|
||||
* Each protocol may not support all of the features that Smithy allows. For instance, `rpcv2Cbor`
|
||||
* does not support HTTP bindings other than `@httpError`. `ServerProtocolBasedTransformationFactory` is a factory
|
||||
* object that transforms the model and removes specific traits based on the protocol being instantiated.
|
||||
*
|
||||
* In the long term, this class will be removed, and each protocol should be resilient enough to ignore extra
|
||||
* traits that the model is annotated with. This will be addressed when we fix issue
|
||||
* [#2979](https://github.com/smithy-lang/smithy-rs/issues/2979).
|
||||
*/
|
||||
object ServerProtocolBasedTransformationFactory {
|
||||
fun transform(
|
||||
model: Model,
|
||||
settings: ServerRustSettings,
|
||||
): Model {
|
||||
val service = settings.getService(model)
|
||||
if (!service.hasTrait<Rpcv2CborTrait>()) {
|
||||
return model
|
||||
}
|
||||
|
||||
// `rpcv2Cbor` does not support:
|
||||
// 1. `@httpPayload` trait.
|
||||
// 2. `@httpLabel` trait.
|
||||
// 3. `@streaming` trait applied to a `Blob` (data streaming).
|
||||
return ModelTransformer.create().mapShapes(model) { shape ->
|
||||
when (shape) {
|
||||
is OperationShape -> shape.removeTraitIfPresent(HttpTrait.ID)
|
||||
is MemberShape -> {
|
||||
shape
|
||||
.removeTraitIfPresent(HttpLabelTrait.ID)
|
||||
.removeTraitIfPresent(HttpPayloadTrait.ID)
|
||||
}
|
||||
is BlobShape -> {
|
||||
shape.removeTraitIfPresent(StreamingTrait.ID)
|
||||
}
|
||||
else -> shape
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun <T : Shape, B> T.removeTraitIfPresent(
|
||||
traitId: ShapeId,
|
||||
): T
|
||||
where T : ToSmithyBuilder<T>,
|
||||
B : AbstractShapeBuilder<B, T>,
|
||||
B : SmithyBuilder<T> {
|
||||
return if (this.hasTrait(traitId)) {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
(this.toBuilder() as B).removeTrait(traitId).build()
|
||||
} else {
|
||||
this
|
||||
}
|
||||
}
|
||||
}
|
|
@ -6,17 +6,103 @@
|
|||
package software.amazon.smithy.rust.codegen.server.smithy
|
||||
|
||||
import io.kotest.inspectors.forAll
|
||||
import io.kotest.matchers.ints.shouldBeGreaterThan
|
||||
import io.kotest.matchers.shouldBe
|
||||
import org.junit.jupiter.api.Test
|
||||
import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
|
||||
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
|
||||
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
|
||||
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
|
||||
import software.amazon.smithy.model.Model
|
||||
import software.amazon.smithy.model.shapes.BooleanShape
|
||||
import software.amazon.smithy.model.shapes.ListShape
|
||||
import software.amazon.smithy.model.shapes.MapShape
|
||||
import software.amazon.smithy.model.shapes.MemberShape
|
||||
import software.amazon.smithy.model.shapes.ServiceShape
|
||||
import software.amazon.smithy.model.shapes.ShapeId
|
||||
import software.amazon.smithy.model.shapes.StringShape
|
||||
import software.amazon.smithy.model.shapes.StructureShape
|
||||
import software.amazon.smithy.model.traits.AbstractTrait
|
||||
import software.amazon.smithy.model.transform.ModelTransformer
|
||||
import software.amazon.smithy.protocol.traits.Rpcv2CborTrait
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
|
||||
import software.amazon.smithy.rust.codegen.core.util.lookup
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider
|
||||
import java.io.File
|
||||
|
||||
enum class ModelProtocol(val trait: AbstractTrait) {
|
||||
AwsJson10(AwsJson1_0Trait.builder().build()),
|
||||
AwsJson11(AwsJson1_1Trait.builder().build()),
|
||||
RestJson(RestJson1Trait.builder().build()),
|
||||
RestXml(RestXmlTrait.builder().build()),
|
||||
Rpcv2Cbor(Rpcv2CborTrait.builder().build()),
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the Smithy constraints model from the common repository, with the specified protocol
|
||||
* applied to the service.
|
||||
*/
|
||||
fun loadSmithyConstraintsModelForProtocol(modelProtocol: ModelProtocol): Pair<Model, ShapeId> {
|
||||
val (model, serviceShapeId) = loadSmithyConstraintsModel()
|
||||
return Pair(model.replaceProtocolTrait(serviceShapeId, modelProtocol), serviceShapeId)
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads the Smithy constraints model defined in the common repository and returns the model along with
|
||||
* the service shape defined in it.
|
||||
*/
|
||||
fun loadSmithyConstraintsModel(): Pair<Model, ShapeId> {
|
||||
val filePath = "../codegen-core/common-test-models/constraints.smithy"
|
||||
val model =
|
||||
File(filePath).readText().asSmithyModel()
|
||||
val serviceShapeId = model.shapes().filter { it.isServiceShape }.findFirst().orElseThrow().id
|
||||
return Pair(model, serviceShapeId)
|
||||
}
|
||||
|
||||
/**
|
||||
* Removes all existing protocol traits annotated on the given service,
|
||||
* then sets the provided `protocol` as the sole protocol trait for the service.
|
||||
*/
|
||||
fun Model.replaceProtocolTrait(
|
||||
serviceShapeId: ShapeId,
|
||||
modelProtocol: ModelProtocol,
|
||||
): Model {
|
||||
val serviceBuilder =
|
||||
this.expectShape(serviceShapeId, ServiceShape::class.java).toBuilder()
|
||||
for (p in ModelProtocol.values()) {
|
||||
serviceBuilder.removeTrait(p.trait.toShapeId())
|
||||
}
|
||||
val service = serviceBuilder.addTrait(modelProtocol.trait).build()
|
||||
return ModelTransformer.create().replaceShapes(this, listOf(service))
|
||||
}
|
||||
|
||||
fun List<ShapeId>.containsAnyShapeId(ids: Collection<ShapeId>): Boolean {
|
||||
return ids.any { id -> this.any { shape -> shape == id } }
|
||||
}
|
||||
|
||||
/**
|
||||
* Removes the given operations from the model.
|
||||
*/
|
||||
fun Model.removeOperations(
|
||||
serviceShapeId: ShapeId,
|
||||
operationsToRemove: List<ShapeId>,
|
||||
): Model {
|
||||
val service = this.expectShape(serviceShapeId, ServiceShape::class.java)
|
||||
val serviceBuilder = service.toBuilder()
|
||||
// The operation must exist in the service.
|
||||
service.operations.map { it.toShapeId() }.containsAll(operationsToRemove) shouldBe true
|
||||
// Remove all operations.
|
||||
for (opToRemove in operationsToRemove) {
|
||||
serviceBuilder.removeOperation(opToRemove)
|
||||
}
|
||||
val changedModel = ModelTransformer.create().replaceShapes(this, listOf(serviceBuilder.build()))
|
||||
// The operation must not exist in the updated service.
|
||||
val changedService = changedModel.expectShape(serviceShapeId, ServiceShape::class.java)
|
||||
changedService.operations.size shouldBeGreaterThan 0
|
||||
changedService.operations.map { it.toShapeId() }.containsAnyShapeId(operationsToRemove) shouldBe false
|
||||
|
||||
return changedModel
|
||||
}
|
||||
|
||||
class ConstraintsTest {
|
||||
private val model =
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
package software.amazon.smithy.rust.codegen.server.smithy.protocols.serialize
|
||||
|
||||
import org.junit.jupiter.api.Test
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.ServerAdditionalSettings
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.ModelProtocol
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.loadSmithyConstraintsModelForProtocol
|
||||
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest
|
||||
|
||||
class CborConstraintsIntegrationTest {
|
||||
@Test
|
||||
fun `ensure CBOR implementation works for all constraint types`() {
|
||||
val (model, serviceShape) = loadSmithyConstraintsModelForProtocol(ModelProtocol.Rpcv2Cbor)
|
||||
// The test should compile; no further testing is required.
|
||||
serverIntegrationTest(
|
||||
model,
|
||||
IntegrationTestParams(
|
||||
service = serviceShape.toString(),
|
||||
additionalSettings = ServerAdditionalSettings.builder().generateCodegenComments().toObjectNode(),
|
||||
),
|
||||
) { _, _ ->
|
||||
}
|
||||
}
|
||||
}
|
|
@ -32,7 +32,7 @@ pub fn parse_error_metadata(
|
|||
_response_headers: &Headers,
|
||||
response_body: &[u8],
|
||||
) -> Result<ErrorMetadataBuilder, DeserializeError> {
|
||||
fn error_code(
|
||||
fn error_code_and_message(
|
||||
mut builder: ErrorMetadataBuilder,
|
||||
decoder: &mut Decoder,
|
||||
) -> Result<ErrorMetadataBuilder, DeserializeError> {
|
||||
|
@ -41,6 +41,14 @@ pub fn parse_error_metadata(
|
|||
let code = decoder.str()?;
|
||||
builder.code(sanitize_error_code(&code))
|
||||
}
|
||||
"message" | "Message" | "errorMessage" => {
|
||||
// Silently skip if `message` is not a string. This allows for custom error
|
||||
// structures that might use different types for the message field.
|
||||
match decoder.str() {
|
||||
Ok(message) => builder.message(message),
|
||||
Err(_) => builder,
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
decoder.skip()?;
|
||||
builder
|
||||
|
@ -60,13 +68,13 @@ pub fn parse_error_metadata(
|
|||
break;
|
||||
}
|
||||
_ => {
|
||||
builder = error_code(builder, decoder)?;
|
||||
builder = error_code_and_message(builder, decoder)?;
|
||||
}
|
||||
};
|
||||
},
|
||||
Some(n) => {
|
||||
for _ in 0..n {
|
||||
builder = error_code(builder, decoder)?;
|
||||
builder = error_code_and_message(builder, decoder)?;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue