mirror of https://github.com/smithy-lang/smithy-rs
Serialize Request bodies for restJson (#255)
* Serialize Request Bodies for RestJson * Move the bodies to the serializer module * RestJson body CR feedback and bug fixes
This commit is contained in:
parent
4f958f3709
commit
41d948c597
|
@ -36,6 +36,10 @@ val CodegenTests = listOf(
|
||||||
"aws.protocoltests.restjson#RestJson",
|
"aws.protocoltests.restjson#RestJson",
|
||||||
"rest_json"
|
"rest_json"
|
||||||
),
|
),
|
||||||
|
CodegenTest(
|
||||||
|
"aws.protocoltests.restjson#RestJsonExtras",
|
||||||
|
"rest_json_extas"
|
||||||
|
),
|
||||||
CodegenTest(
|
CodegenTest(
|
||||||
"crate#Config",
|
"crate#Config",
|
||||||
"naming_test", """
|
"naming_test", """
|
||||||
|
|
|
@ -0,0 +1,59 @@
|
||||||
|
$version: "1.0"
|
||||||
|
|
||||||
|
namespace aws.protocoltests.restjson
|
||||||
|
|
||||||
|
use aws.protocols#restJson1
|
||||||
|
use aws.api#service
|
||||||
|
use smithy.test#httpRequestTests
|
||||||
|
|
||||||
|
|
||||||
|
/// A REST JSON service that sends JSON requests and responses.
|
||||||
|
@service(sdkId: "Rest Json Protocol")
|
||||||
|
@restJson1
|
||||||
|
service RestJsonExtras {
|
||||||
|
version: "2019-12-16",
|
||||||
|
operations: [EnumPayload, StringPayload]
|
||||||
|
}
|
||||||
|
|
||||||
|
@http(uri: "/EnumPayload", method: "POST")
|
||||||
|
@httpRequestTests([
|
||||||
|
{
|
||||||
|
id: "EnumPayload",
|
||||||
|
uri: "/EnumPayload",
|
||||||
|
body: "enumvalue",
|
||||||
|
params: { payload: "enumvalue" },
|
||||||
|
method: "POST",
|
||||||
|
protocol: "aws.protocols#restJson1"
|
||||||
|
}
|
||||||
|
])
|
||||||
|
operation EnumPayload {
|
||||||
|
input: EnumPayloadInput
|
||||||
|
}
|
||||||
|
|
||||||
|
structure EnumPayloadInput {
|
||||||
|
@httpPayload
|
||||||
|
payload: StringEnum
|
||||||
|
}
|
||||||
|
|
||||||
|
@enum([{"value": "enumvalue", "name": "V"}])
|
||||||
|
string StringEnum
|
||||||
|
|
||||||
|
@http(uri: "/StringPayload", method: "POST")
|
||||||
|
@httpRequestTests([
|
||||||
|
{
|
||||||
|
id: "StringPayload",
|
||||||
|
uri: "/StringPayload",
|
||||||
|
body: "rawstring",
|
||||||
|
params: { payload: "rawstring" },
|
||||||
|
method: "POST",
|
||||||
|
protocol: "aws.protocols#restJson1"
|
||||||
|
}
|
||||||
|
])
|
||||||
|
operation StringPayload {
|
||||||
|
input: StringPayloadInput
|
||||||
|
}
|
||||||
|
|
||||||
|
structure StringPayloadInput {
|
||||||
|
@httpPayload
|
||||||
|
payload: String
|
||||||
|
}
|
|
@ -40,6 +40,8 @@ import software.amazon.smithy.model.traits.HttpLabelTrait
|
||||||
import software.amazon.smithy.rust.codegen.rustlang.RustType
|
import software.amazon.smithy.rust.codegen.rustlang.RustType
|
||||||
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
|
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
|
||||||
import software.amazon.smithy.rust.codegen.rustlang.Writable
|
import software.amazon.smithy.rust.codegen.rustlang.Writable
|
||||||
|
import software.amazon.smithy.rust.codegen.smithy.traits.InputBodyTrait
|
||||||
|
import software.amazon.smithy.rust.codegen.smithy.traits.OutputBodyTrait
|
||||||
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
|
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
|
||||||
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticOutputTrait
|
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticOutputTrait
|
||||||
import software.amazon.smithy.rust.codegen.util.toSnakeCase
|
import software.amazon.smithy.rust.codegen.util.toSnakeCase
|
||||||
|
@ -231,6 +233,7 @@ class SymbolVisitor(
|
||||||
val isError = shape.hasTrait(ErrorTrait::class.java)
|
val isError = shape.hasTrait(ErrorTrait::class.java)
|
||||||
val isInput = shape.hasTrait(SyntheticInputTrait::class.java)
|
val isInput = shape.hasTrait(SyntheticInputTrait::class.java)
|
||||||
val isOutput = shape.hasTrait(SyntheticOutputTrait::class.java)
|
val isOutput = shape.hasTrait(SyntheticOutputTrait::class.java)
|
||||||
|
val isBody = shape.hasTrait(InputBodyTrait::class.java) || shape.hasTrait(OutputBodyTrait::class.java)
|
||||||
val name = StringUtils.capitalize(shape.id.name).letIf(isError && config.codegenConfig.renameExceptions) {
|
val name = StringUtils.capitalize(shape.id.name).letIf(isError && config.codegenConfig.renameExceptions) {
|
||||||
// TODO: Do we want to do this?
|
// TODO: Do we want to do this?
|
||||||
// https://github.com/awslabs/smithy-rs/issues/77
|
// https://github.com/awslabs/smithy-rs/issues/77
|
||||||
|
@ -241,6 +244,7 @@ class SymbolVisitor(
|
||||||
isError -> builder.locatedIn(Errors)
|
isError -> builder.locatedIn(Errors)
|
||||||
isInput -> builder.locatedIn(Inputs)
|
isInput -> builder.locatedIn(Inputs)
|
||||||
isOutput -> builder.locatedIn(Outputs)
|
isOutput -> builder.locatedIn(Outputs)
|
||||||
|
isBody -> builder.locatedIn(Serializers)
|
||||||
else -> builder.locatedIn(Models)
|
else -> builder.locatedIn(Models)
|
||||||
}.build()
|
}.build()
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,13 +49,24 @@ abstract class HttpProtocolGenerator(
|
||||||
) {
|
) {
|
||||||
private val symbolProvider = protocolConfig.symbolProvider
|
private val symbolProvider = protocolConfig.symbolProvider
|
||||||
private val model = protocolConfig.model
|
private val model = protocolConfig.model
|
||||||
fun renderOperation(operationWriter: RustWriter, inputWriter: RustWriter, operationShape: OperationShape, customizations: List<OperationCustomization>) {
|
fun renderOperation(
|
||||||
|
operationWriter: RustWriter,
|
||||||
|
inputWriter: RustWriter,
|
||||||
|
operationShape: OperationShape,
|
||||||
|
customizations: List<OperationCustomization>
|
||||||
|
) {
|
||||||
/* if (operationShape.hasTrait(EndpointTrait::class.java)) {
|
/* if (operationShape.hasTrait(EndpointTrait::class.java)) {
|
||||||
TODO("https://github.com/awslabs/smithy-rs/issues/197")
|
TODO("https://github.com/awslabs/smithy-rs/issues/197")
|
||||||
} */
|
} */
|
||||||
val inputShape = operationShape.inputShape(model)
|
val inputShape = operationShape.inputShape(model)
|
||||||
val inputSymbol = symbolProvider.toSymbol(inputShape)
|
val inputSymbol = symbolProvider.toSymbol(inputShape)
|
||||||
val builderGenerator = OperationInputBuilderGenerator(model, symbolProvider, operationShape, protocolConfig.moduleName, customizations)
|
val builderGenerator = OperationInputBuilderGenerator(
|
||||||
|
model,
|
||||||
|
symbolProvider,
|
||||||
|
operationShape,
|
||||||
|
protocolConfig.moduleName,
|
||||||
|
customizations
|
||||||
|
)
|
||||||
builderGenerator.render(inputWriter)
|
builderGenerator.render(inputWriter)
|
||||||
// impl OperationInputShape { ... }
|
// impl OperationInputShape { ... }
|
||||||
|
|
||||||
|
@ -63,7 +74,7 @@ abstract class HttpProtocolGenerator(
|
||||||
toHttpRequestImpl(this, operationShape, inputShape)
|
toHttpRequestImpl(this, operationShape, inputShape)
|
||||||
val shapeId = inputShape.expectTrait(SyntheticInputTrait::class.java).body
|
val shapeId = inputShape.expectTrait(SyntheticInputTrait::class.java).body
|
||||||
val body = shapeId?.let { model.expectShape(it, StructureShape::class.java) }
|
val body = shapeId?.let { model.expectShape(it, StructureShape::class.java) }
|
||||||
toBodyImpl(this, inputShape, body)
|
toBodyImpl(this, inputShape, body, operationShape)
|
||||||
// TODO: streaming shapes need special support
|
// TODO: streaming shapes need special support
|
||||||
rustBlock(
|
rustBlock(
|
||||||
"pub fn assemble(builder: #1T, body: #3T) -> #2T<#3T>",
|
"pub fn assemble(builder: #1T, body: #3T) -> #2T<#3T>",
|
||||||
|
@ -130,14 +141,18 @@ abstract class HttpProtocolGenerator(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
protected fun fromResponseFun(implBlockWriter: RustWriter, operationShape: OperationShape, f: RustWriter.() -> Unit) {
|
protected fun fromResponseFun(
|
||||||
|
implBlockWriter: RustWriter,
|
||||||
|
operationShape: OperationShape,
|
||||||
|
block: RustWriter.() -> Unit
|
||||||
|
) {
|
||||||
implBlockWriter.rustBlock(
|
implBlockWriter.rustBlock(
|
||||||
"fn from_response(response: &#T<impl AsRef<[u8]>>) -> Result<#T, #T>",
|
"fn from_response(response: &#T<impl AsRef<[u8]>>) -> Result<#T, #T>",
|
||||||
RuntimeType.Http("response::Response"),
|
RuntimeType.Http("response::Response"),
|
||||||
symbolProvider.toSymbol(operationShape.outputShape(model)),
|
symbolProvider.toSymbol(operationShape.outputShape(model)),
|
||||||
operationShape.errorSymbol(symbolProvider)
|
operationShape.errorSymbol(symbolProvider)
|
||||||
) {
|
) {
|
||||||
f(this)
|
block(this)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -150,12 +165,21 @@ abstract class HttpProtocolGenerator(
|
||||||
*
|
*
|
||||||
* Your implementation MUST call [bodyBuilderFun] to create the public method.
|
* Your implementation MUST call [bodyBuilderFun] to create the public method.
|
||||||
*/
|
*/
|
||||||
abstract fun toBodyImpl(implBlockWriter: RustWriter, inputShape: StructureShape, inputBody: StructureShape?)
|
abstract fun toBodyImpl(
|
||||||
|
implBlockWriter: RustWriter,
|
||||||
|
inputShape: StructureShape,
|
||||||
|
inputBody: StructureShape?,
|
||||||
|
operationShape: OperationShape
|
||||||
|
)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add necessary methods to the impl block for the input shape.
|
* Add necessary methods to the impl block for the input shape.
|
||||||
*
|
*
|
||||||
* Your implementation MUST call [httpBuilderFun] to create the public method.
|
* Your implementation MUST call [httpBuilderFun] to create the public method.
|
||||||
*/
|
*/
|
||||||
abstract fun toHttpRequestImpl(implBlockWriter: RustWriter, operationShape: OperationShape, inputShape: StructureShape)
|
abstract fun toHttpRequestImpl(
|
||||||
|
implBlockWriter: RustWriter,
|
||||||
|
operationShape: OperationShape,
|
||||||
|
inputShape: StructureShape
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -172,7 +172,7 @@ class HttpProtocolTestGenerator(
|
||||||
checkRequiredHeaders(this, httpRequestTestCase.requireHeaders)
|
checkRequiredHeaders(this, httpRequestTestCase.requireHeaders)
|
||||||
if (protocolSupport.requestBodySerialization) {
|
if (protocolSupport.requestBodySerialization) {
|
||||||
// "If no request body is defined, then no assertions are made about the body of the message."
|
// "If no request body is defined, then no assertions are made about the body of the message."
|
||||||
httpRequestTestCase.body.orNull()?.let { body ->
|
httpRequestTestCase.body.orNull()?.also { body ->
|
||||||
checkBody(this, body, httpRequestTestCase.bodyMediaType.orNull())
|
checkBody(this, body, httpRequestTestCase.bodyMediaType.orNull())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -253,7 +253,7 @@ class HttpProtocolTestGenerator(
|
||||||
rustWriter.write("""let body = http_request.body().bytes().expect("body should be strict");""")
|
rustWriter.write("""let body = http_request.body().bytes().expect("body should be strict");""")
|
||||||
if (body == "") {
|
if (body == "") {
|
||||||
rustWriter.write("// No body")
|
rustWriter.write("// No body")
|
||||||
rustWriter.write("assert!(&body.is_empty());")
|
rustWriter.write("assert_eq!(std::str::from_utf8(body).unwrap(), ${"".dq()});")
|
||||||
} else {
|
} else {
|
||||||
// When we generate a body instead of a stub, drop the trailing `;` and enable the assertion
|
// When we generate a body instead of a stub, drop the trailing `;` and enable the assertion
|
||||||
assertOk(rustWriter) {
|
assertOk(rustWriter) {
|
||||||
|
@ -383,7 +383,9 @@ class HttpProtocolTestGenerator(
|
||||||
// or because they are flaky
|
// or because they are flaky
|
||||||
private val DisableTests = setOf(
|
private val DisableTests = setOf(
|
||||||
// This test is flaky because of set ordering serialization https://github.com/awslabs/smithy-rs/issues/37
|
// This test is flaky because of set ordering serialization https://github.com/awslabs/smithy-rs/issues/37
|
||||||
"AwsJson11Enums"
|
"AwsJson11Enums",
|
||||||
|
"RestJsonJsonEnums",
|
||||||
|
"RestJsonLists"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,7 @@ import software.amazon.smithy.rust.codegen.rustlang.rustBlock
|
||||||
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
|
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
|
||||||
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
|
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
|
||||||
import software.amazon.smithy.rust.codegen.util.dq
|
import software.amazon.smithy.rust.codegen.util.dq
|
||||||
|
import software.amazon.smithy.rust.codegen.util.expectMember
|
||||||
|
|
||||||
fun HttpTrait.uriFormatString(): String {
|
fun HttpTrait.uriFormatString(): String {
|
||||||
val base = uri.segments.joinToString("/", prefix = "/") {
|
val base = uri.segments.joinToString("/", prefix = "/") {
|
||||||
|
@ -160,7 +161,7 @@ class HttpTraitBindingGenerator(
|
||||||
private fun uriBase(writer: RustWriter) {
|
private fun uriBase(writer: RustWriter) {
|
||||||
val formatString = httpTrait.uriFormatString()
|
val formatString = httpTrait.uriFormatString()
|
||||||
val args = httpTrait.uri.labels.map { label ->
|
val args = httpTrait.uri.labels.map { label ->
|
||||||
val member = inputShape.getMember(label.content).get()
|
val member = inputShape.expectMember(label.content)
|
||||||
"${label.content} = ${labelFmtFun(model.expectShape(member.target), member, label)}"
|
"${label.content} = ${labelFmtFun(model.expectShape(member.target), member, label)}"
|
||||||
}
|
}
|
||||||
val combinedArgs = listOf(formatString, *args.toTypedArray())
|
val combinedArgs = listOf(formatString, *args.toTypedArray())
|
||||||
|
|
|
@ -41,6 +41,7 @@ import software.amazon.smithy.rust.codegen.smithy.isOptional
|
||||||
import software.amazon.smithy.rust.codegen.smithy.rustType
|
import software.amazon.smithy.rust.codegen.smithy.rustType
|
||||||
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
|
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
|
||||||
import software.amazon.smithy.rust.codegen.util.dq
|
import software.amazon.smithy.rust.codegen.util.dq
|
||||||
|
import software.amazon.smithy.rust.codegen.util.expectMember
|
||||||
import software.amazon.smithy.rust.codegen.util.toPascalCase
|
import software.amazon.smithy.rust.codegen.util.toPascalCase
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -187,7 +188,7 @@ class Instantiator(
|
||||||
check(data.members.size == 1)
|
check(data.members.size == 1)
|
||||||
val variant = data.members.iterator().next()
|
val variant = data.members.iterator().next()
|
||||||
val memberName = variant.key.value
|
val memberName = variant.key.value
|
||||||
val member = shape.getMember(memberName).get()
|
val member = shape.expectMember(memberName)
|
||||||
.let { model.expectShape(it.target) }
|
.let { model.expectShape(it.target) }
|
||||||
// TODO: refactor this detail into UnionGenerator
|
// TODO: refactor this detail into UnionGenerator
|
||||||
writer.write("#T::${memberName.toPascalCase()}", unionSymbol)
|
writer.write("#T::${memberName.toPascalCase()}", unionSymbol)
|
||||||
|
@ -278,8 +279,7 @@ class Instantiator(
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun getMember(shape: StructureShape, key: StringNode): Pair<MemberShape, Shape> {
|
private fun getMember(shape: StructureShape, key: StringNode): Pair<MemberShape, Shape> {
|
||||||
val memberShape = shape.getMember(key.value)
|
val memberShape = shape.expectMember(key.value)
|
||||||
.orElseThrow { IllegalArgumentException("$shape did not have member ${key.value}") }
|
|
||||||
return memberShape to model.expectShape(memberShape.target)
|
return memberShape to model.expectShape(memberShape.target)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -56,7 +56,7 @@ class BasicAwsJsonFactory(private val version: AwsJsonVersion) : ProtocolGenerat
|
||||||
protocolConfig: ProtocolConfig
|
protocolConfig: ProtocolConfig
|
||||||
): BasicAwsJsonGenerator = BasicAwsJsonGenerator(protocolConfig, version)
|
): BasicAwsJsonGenerator = BasicAwsJsonGenerator(protocolConfig, version)
|
||||||
|
|
||||||
private val shapeIfHasMembers: StructureModifier = { shape: StructureShape? ->
|
private val shapeIfHasMembers: StructureModifier = { _, shape: StructureShape? ->
|
||||||
if (shape?.members().isNullOrEmpty()) {
|
if (shape?.members().isNullOrEmpty()) {
|
||||||
null
|
null
|
||||||
} else {
|
} else {
|
||||||
|
@ -178,7 +178,12 @@ class BasicAwsJsonGenerator(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun toBodyImpl(implBlockWriter: RustWriter, inputShape: StructureShape, inputBody: StructureShape?) {
|
override fun toBodyImpl(
|
||||||
|
implBlockWriter: RustWriter,
|
||||||
|
inputShape: StructureShape,
|
||||||
|
inputBody: StructureShape?,
|
||||||
|
operationShape: OperationShape
|
||||||
|
) {
|
||||||
if (inputBody == null) {
|
if (inputBody == null) {
|
||||||
bodyBuilderFun(implBlockWriter) {
|
bodyBuilderFun(implBlockWriter) {
|
||||||
write("\"{}\".to_string().into()")
|
write("\"{}\".to_string().into()")
|
||||||
|
|
|
@ -6,30 +6,64 @@
|
||||||
package software.amazon.smithy.rust.codegen.smithy.protocols
|
package software.amazon.smithy.rust.codegen.smithy.protocols
|
||||||
|
|
||||||
import software.amazon.smithy.model.Model
|
import software.amazon.smithy.model.Model
|
||||||
|
import software.amazon.smithy.model.knowledge.HttpBinding
|
||||||
import software.amazon.smithy.model.knowledge.HttpBindingIndex
|
import software.amazon.smithy.model.knowledge.HttpBindingIndex
|
||||||
|
import software.amazon.smithy.model.shapes.BlobShape
|
||||||
|
import software.amazon.smithy.model.shapes.DocumentShape
|
||||||
import software.amazon.smithy.model.shapes.OperationShape
|
import software.amazon.smithy.model.shapes.OperationShape
|
||||||
|
import software.amazon.smithy.model.shapes.Shape
|
||||||
|
import software.amazon.smithy.model.shapes.StringShape
|
||||||
import software.amazon.smithy.model.shapes.StructureShape
|
import software.amazon.smithy.model.shapes.StructureShape
|
||||||
|
import software.amazon.smithy.model.shapes.UnionShape
|
||||||
|
import software.amazon.smithy.model.traits.EnumTrait
|
||||||
import software.amazon.smithy.model.traits.HttpTrait
|
import software.amazon.smithy.model.traits.HttpTrait
|
||||||
|
import software.amazon.smithy.model.traits.TimestampFormatTrait
|
||||||
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
|
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
|
||||||
import software.amazon.smithy.rust.codegen.rustlang.rust
|
import software.amazon.smithy.rust.codegen.rustlang.rust
|
||||||
|
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
|
||||||
|
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
|
||||||
|
import software.amazon.smithy.rust.codegen.rustlang.writable
|
||||||
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
|
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
|
||||||
|
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
|
||||||
import software.amazon.smithy.rust.codegen.smithy.generators.HttpProtocolGenerator
|
import software.amazon.smithy.rust.codegen.smithy.generators.HttpProtocolGenerator
|
||||||
import software.amazon.smithy.rust.codegen.smithy.generators.HttpTraitBindingGenerator
|
import software.amazon.smithy.rust.codegen.smithy.generators.HttpTraitBindingGenerator
|
||||||
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig
|
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig
|
||||||
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolGeneratorFactory
|
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolGeneratorFactory
|
||||||
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolSupport
|
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolSupport
|
||||||
|
import software.amazon.smithy.rust.codegen.smithy.isOptional
|
||||||
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
|
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
|
||||||
import software.amazon.smithy.rust.codegen.util.dq
|
import software.amazon.smithy.rust.codegen.util.dq
|
||||||
|
import software.amazon.smithy.rust.codegen.util.expectMember
|
||||||
|
|
||||||
class AwsRestJsonFactory : ProtocolGeneratorFactory<AwsRestJsonGenerator> {
|
class AwsRestJsonFactory : ProtocolGeneratorFactory<AwsRestJsonGenerator> {
|
||||||
override fun buildProtocolGenerator(
|
override fun buildProtocolGenerator(
|
||||||
protocolConfig: ProtocolConfig
|
protocolConfig: ProtocolConfig
|
||||||
): AwsRestJsonGenerator = AwsRestJsonGenerator(protocolConfig)
|
): AwsRestJsonGenerator = AwsRestJsonGenerator(protocolConfig)
|
||||||
|
|
||||||
|
/** Create a synthetic awsJsonInputBody if specified
|
||||||
|
* A body is created iff no member of [input] is targeted with the `PAYLOAD` trait. If a member is targeted with
|
||||||
|
* the payload trait, we don't need to create an input body.
|
||||||
|
*/
|
||||||
|
private fun awsJsonInputBody(model: Model, operation: OperationShape, input: StructureShape?): StructureShape? {
|
||||||
|
if (input == null) {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
val bindingIndex = HttpBindingIndex.of(model)
|
||||||
|
val bindings: MutableMap<String, HttpBinding> = bindingIndex.getRequestBindings(operation)
|
||||||
|
val bodyMembers = input.members().filter { member ->
|
||||||
|
bindings[member.memberName]?.location == HttpBinding.Location.DOCUMENT
|
||||||
|
}
|
||||||
|
|
||||||
|
return if (bodyMembers.isNotEmpty()) {
|
||||||
|
input.toBuilder().members(bodyMembers).build()
|
||||||
|
} else {
|
||||||
|
null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
override fun transformModel(model: Model): Model {
|
override fun transformModel(model: Model): Model {
|
||||||
// TODO: AWSRestJson determines the body from HTTP traits
|
|
||||||
return OperationNormalizer(model).transformModel(
|
return OperationNormalizer(model).transformModel(
|
||||||
inputBodyFactory = OperationNormalizer.NoBody,
|
inputBodyFactory = { op, input -> awsJsonInputBody(model, op, input) },
|
||||||
outputBodyFactory = OperationNormalizer.NoBody
|
outputBodyFactory = OperationNormalizer.NoBody
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -37,15 +71,23 @@ class AwsRestJsonFactory : ProtocolGeneratorFactory<AwsRestJsonGenerator> {
|
||||||
override fun support(): ProtocolSupport {
|
override fun support(): ProtocolSupport {
|
||||||
// TODO: Support body for RestJson
|
// TODO: Support body for RestJson
|
||||||
return ProtocolSupport(
|
return ProtocolSupport(
|
||||||
requestBodySerialization = false,
|
requestBodySerialization = true,
|
||||||
responseDeserialization = false,
|
responseDeserialization = false,
|
||||||
errorDeserialization = false
|
errorDeserialization = false
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun symbolProvider(model: Model, base: RustSymbolProvider): RustSymbolProvider {
|
||||||
|
return JsonSerializerSymbolProvider(
|
||||||
|
model,
|
||||||
|
SyntheticBodySymbolProvider(model, base),
|
||||||
|
TimestampFormatTrait.Format.EPOCH_SECONDS
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class AwsRestJsonGenerator(
|
class AwsRestJsonGenerator(
|
||||||
protocolConfig: ProtocolConfig
|
private val protocolConfig: ProtocolConfig
|
||||||
) : HttpProtocolGenerator(protocolConfig) {
|
) : HttpProtocolGenerator(protocolConfig) {
|
||||||
// restJson1 requires all operations to use the HTTP trait
|
// restJson1 requires all operations to use the HTTP trait
|
||||||
|
|
||||||
|
@ -66,9 +108,94 @@ class AwsRestJsonGenerator(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun toBodyImpl(implBlockWriter: RustWriter, inputShape: StructureShape, inputBody: StructureShape?) {
|
private fun serializeViaSyntheticBody(
|
||||||
|
implBlockWriter: RustWriter,
|
||||||
|
inputBody: StructureShape
|
||||||
|
) {
|
||||||
|
val bodySymbol = protocolConfig.symbolProvider.toSymbol(inputBody)
|
||||||
|
implBlockWriter.rustBlock("fn body(&self) -> #T", bodySymbol) {
|
||||||
|
rustBlock("#T", bodySymbol) {
|
||||||
|
for (member in inputBody.members()) {
|
||||||
|
val name = protocolConfig.symbolProvider.toMemberName(member)
|
||||||
|
write("$name: &self.$name,")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
bodyBuilderFun(implBlockWriter) {
|
bodyBuilderFun(implBlockWriter) {
|
||||||
rust(""""body not generated yet".into()""")
|
write("""#T(&self.body()).expect("serialization should succeed")""", RuntimeType.SerdeJson("to_vec"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun toBodyImpl(
|
||||||
|
implBlockWriter: RustWriter,
|
||||||
|
inputShape: StructureShape,
|
||||||
|
inputBody: StructureShape?,
|
||||||
|
operationShape: OperationShape
|
||||||
|
) {
|
||||||
|
// If we created a synthetic input body, serialize that
|
||||||
|
if (inputBody != null) {
|
||||||
|
return serializeViaSyntheticBody(implBlockWriter, inputBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, we need to serialize via the HTTP payload trait
|
||||||
|
val bindings = httpIndex.getRequestBindings(operationShape).toList()
|
||||||
|
val payload: Pair<String, HttpBinding>? =
|
||||||
|
bindings.firstOrNull { (_, binding) -> binding.location == HttpBinding.Location.PAYLOAD }
|
||||||
|
val payloadSerde = payload?.let { (payloadMemberName, _) ->
|
||||||
|
val member = inputShape.expectMember(payloadMemberName)
|
||||||
|
val rustMemberName = "self.${symbolProvider.toMemberName(member)}"
|
||||||
|
val targetShape = model.expectShape(member.target)
|
||||||
|
writable {
|
||||||
|
val payloadName = safeName()
|
||||||
|
rust("let $payloadName = &$rustMemberName;")
|
||||||
|
// If this targets a member & the member is None, return an empty vec
|
||||||
|
if (symbolProvider.toSymbol(member).isOptional()) {
|
||||||
|
rust(
|
||||||
|
"""
|
||||||
|
let $payloadName = match $payloadName.as_ref() {
|
||||||
|
Some(t) => t,
|
||||||
|
None => return vec![]
|
||||||
|
};"""
|
||||||
|
)
|
||||||
|
}
|
||||||
|
renderPayload(targetShape, payloadName)
|
||||||
|
}
|
||||||
|
// body is null, no payload set, so this is empty
|
||||||
|
} ?: writable { rust("vec![]") }
|
||||||
|
bodyBuilderFun(implBlockWriter) {
|
||||||
|
payloadSerde(this)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun RustWriter.renderPayload(
|
||||||
|
targetShape: Shape,
|
||||||
|
payloadName: String,
|
||||||
|
) {
|
||||||
|
val serdeToVec = RuntimeType.SerdeJson("to_vec")
|
||||||
|
when (targetShape) {
|
||||||
|
// Write the raw string to the payload
|
||||||
|
is StringShape ->
|
||||||
|
if (targetShape.hasTrait(EnumTrait::class.java)) {
|
||||||
|
rust("$payloadName.as_str().into()")
|
||||||
|
} else {
|
||||||
|
rust("""$payloadName.to_string().into()""")
|
||||||
|
}
|
||||||
|
is BlobShape ->
|
||||||
|
// Write the raw blob to the payload
|
||||||
|
rust("$payloadName.as_ref().into()")
|
||||||
|
is StructureShape, is UnionShape ->
|
||||||
|
// JSON serialize the structure or union targetted
|
||||||
|
rust(
|
||||||
|
"""#T(&$payloadName).expect("serialization should succeed")""",
|
||||||
|
serdeToVec
|
||||||
|
)
|
||||||
|
is DocumentShape ->
|
||||||
|
rustTemplate(
|
||||||
|
"""#{to_vec}(&#{doc_json}::SerDoc(&$payloadName)).expect("serialization should succeed")""",
|
||||||
|
"to_vec" to serdeToVec,
|
||||||
|
"doc_json" to RuntimeType.DocJson
|
||||||
|
)
|
||||||
|
else -> TODO("Unexpected payload target type")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ import software.amazon.smithy.rust.codegen.util.orNull
|
||||||
import java.util.Optional
|
import java.util.Optional
|
||||||
import kotlin.streams.toList
|
import kotlin.streams.toList
|
||||||
|
|
||||||
typealias StructureModifier = (StructureShape?) -> StructureShape?
|
typealias StructureModifier = (OperationShape, StructureShape?) -> StructureShape?
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate synthetic Input and Output structures for operations.
|
* Generate synthetic Input and Output structures for operations.
|
||||||
|
@ -64,6 +64,7 @@ class OperationNormalizer(private val model: Model) {
|
||||||
): List<StructureShape> {
|
): List<StructureShape> {
|
||||||
val outputId = operation.outputId()
|
val outputId = operation.outputId()
|
||||||
val outputBodyShape = outputBodyFactory(
|
val outputBodyShape = outputBodyFactory(
|
||||||
|
operation,
|
||||||
operation.output.map { model.expectShape(it, StructureShape::class.java) }.orNull()
|
operation.output.map { model.expectShape(it, StructureShape::class.java) }.orNull()
|
||||||
)?.let { it.toBuilder().addTrait(OutputBodyTrait()).rename(operation.outputBodyId()).build() }
|
)?.let { it.toBuilder().addTrait(OutputBodyTrait()).rename(operation.outputBodyId()).build() }
|
||||||
val outputShapeBuilder = operation.output.map { shapeId ->
|
val outputShapeBuilder = operation.output.map { shapeId ->
|
||||||
|
@ -79,6 +80,7 @@ class OperationNormalizer(private val model: Model) {
|
||||||
): List<StructureShape> {
|
): List<StructureShape> {
|
||||||
val inputId = operation.inputId()
|
val inputId = operation.inputId()
|
||||||
val inputBodyShape = inputBodyFactory(
|
val inputBodyShape = inputBodyFactory(
|
||||||
|
operation,
|
||||||
operation.input.map {
|
operation.input.map {
|
||||||
val inputShape = model.expectShape(it, StructureShape::class.java)
|
val inputShape = model.expectShape(it, StructureShape::class.java)
|
||||||
inputShape.toBuilder().addTrait(InputBodyTrait()).rename(operation.inputBodyId()).build()
|
inputShape.toBuilder().addTrait(InputBodyTrait()).rename(operation.inputBodyId()).build()
|
||||||
|
@ -101,7 +103,7 @@ class OperationNormalizer(private val model: Model) {
|
||||||
private fun OperationShape.inputBodyId() = ShapeId.fromParts(this.id.namespace, "${this.id.name}InputBody")
|
private fun OperationShape.inputBodyId() = ShapeId.fromParts(this.id.namespace, "${this.id.name}InputBody")
|
||||||
private fun OperationShape.outputBodyId() = ShapeId.fromParts(this.id.namespace, "${this.id.name}OutputBody")
|
private fun OperationShape.outputBodyId() = ShapeId.fromParts(this.id.namespace, "${this.id.name}OutputBody")
|
||||||
|
|
||||||
val NoBody: (StructureShape?) -> StructureShape? = { _ -> null }
|
val NoBody: StructureModifier = { _: OperationShape, _: StructureShape? -> null }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,11 +5,14 @@
|
||||||
|
|
||||||
package software.amazon.smithy.rust.codegen.util
|
package software.amazon.smithy.rust.codegen.util
|
||||||
|
|
||||||
|
import software.amazon.smithy.codegen.core.CodegenException
|
||||||
import software.amazon.smithy.model.Model
|
import software.amazon.smithy.model.Model
|
||||||
|
import software.amazon.smithy.model.shapes.MemberShape
|
||||||
import software.amazon.smithy.model.shapes.OperationShape
|
import software.amazon.smithy.model.shapes.OperationShape
|
||||||
import software.amazon.smithy.model.shapes.Shape
|
import software.amazon.smithy.model.shapes.Shape
|
||||||
import software.amazon.smithy.model.shapes.ShapeId
|
import software.amazon.smithy.model.shapes.ShapeId
|
||||||
import software.amazon.smithy.model.shapes.StructureShape
|
import software.amazon.smithy.model.shapes.StructureShape
|
||||||
|
import software.amazon.smithy.model.shapes.UnionShape
|
||||||
|
|
||||||
inline fun <reified T : Shape> Model.lookup(shapeId: String): T {
|
inline fun <reified T : Shape> Model.lookup(shapeId: String): T {
|
||||||
return this.expectShape(ShapeId.from(shapeId), T::class.java)
|
return this.expectShape(ShapeId.from(shapeId), T::class.java)
|
||||||
|
@ -24,3 +27,9 @@ fun OperationShape.outputShape(model: Model): StructureShape {
|
||||||
// The Rust Smithy generator adds an output to all shapes automatically
|
// The Rust Smithy generator adds an output to all shapes automatically
|
||||||
return model.expectShape(this.output.get(), StructureShape::class.java)
|
return model.expectShape(this.output.get(), StructureShape::class.java)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun StructureShape.expectMember(member: String): MemberShape =
|
||||||
|
this.getMember(member).orElseThrow { CodegenException("$member did not exist on $this") }
|
||||||
|
|
||||||
|
fun UnionShape.expectMember(member: String): MemberShape =
|
||||||
|
this.getMember(member).orElseThrow { CodegenException("$member did not exist on $this") }
|
||||||
|
|
|
@ -129,7 +129,8 @@ class HttpProtocolTestGeneratorTest {
|
||||||
override fun toBodyImpl(
|
override fun toBodyImpl(
|
||||||
implBlockWriter: RustWriter,
|
implBlockWriter: RustWriter,
|
||||||
inputShape: StructureShape,
|
inputShape: StructureShape,
|
||||||
inputBody: StructureShape?
|
inputBody: StructureShape?,
|
||||||
|
operationShape: OperationShape
|
||||||
) {
|
) {
|
||||||
bodyBuilderFun(implBlockWriter) {
|
bodyBuilderFun(implBlockWriter) {
|
||||||
writeWithNoFormatting(body)
|
writeWithNoFormatting(body)
|
||||||
|
|
|
@ -112,10 +112,10 @@ internal class OperationNormalizerTest {
|
||||||
|
|
||||||
val sut = OperationNormalizer(model)
|
val sut = OperationNormalizer(model)
|
||||||
val modified = sut.transformModel(
|
val modified = sut.transformModel(
|
||||||
inputBodyFactory = { input ->
|
inputBodyFactory = { _, input ->
|
||||||
input?.toBuilder()?.members(input.members().filter { it.memberName != "drop" })?.build()
|
input?.toBuilder()?.members(input.members().filter { it.memberName != "drop" })?.build()
|
||||||
},
|
},
|
||||||
outputBodyFactory = { it?.toBuilder()?.members(emptyList())?.build() }
|
outputBodyFactory = { _, output -> output?.toBuilder()?.members(emptyList())?.build() }
|
||||||
)
|
)
|
||||||
val operation = modified.lookup<OperationShape>("smithy.test#MyOp")
|
val operation = modified.lookup<OperationShape>("smithy.test#MyOp")
|
||||||
operation.input.isPresent shouldBe true
|
operation.input.isPresent shouldBe true
|
||||||
|
|
|
@ -251,7 +251,7 @@ fn validate_json_body(actual: &str, expected: &str) -> Result<(), ProtocolTestFa
|
||||||
let actual_json: serde_json::Value =
|
let actual_json: serde_json::Value =
|
||||||
serde_json::from_str(actual).map_err(|e| ProtocolTestFailure::InvalidBodyFormat {
|
serde_json::from_str(actual).map_err(|e| ProtocolTestFailure::InvalidBodyFormat {
|
||||||
expected: "json".to_owned(),
|
expected: "json".to_owned(),
|
||||||
found: e.to_string(),
|
found: e.to_string() + actual,
|
||||||
})?;
|
})?;
|
||||||
let expected_json: serde_json::Value =
|
let expected_json: serde_json::Value =
|
||||||
serde_json::from_str(expected).expect("expected value must be valid JSON");
|
serde_json::from_str(expected).expect("expected value must be valid JSON");
|
||||||
|
|
Loading…
Reference in New Issue