Serialize Request bodies for restJson (#255)

* Serialize Request Bodies for RestJson

* Move the bodies to the serializer module

* RestJson body CR feedback and bug fixes
This commit is contained in:
Russell Cohen 2021-03-22 16:40:50 -04:00 committed by GitHub
parent 4f958f3709
commit 41d948c597
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 266 additions and 28 deletions

View File

@ -36,6 +36,10 @@ val CodegenTests = listOf(
"aws.protocoltests.restjson#RestJson",
"rest_json"
),
CodegenTest(
"aws.protocoltests.restjson#RestJsonExtras",
"rest_json_extas"
),
CodegenTest(
"crate#Config",
"naming_test", """

View File

@ -0,0 +1,59 @@
$version: "1.0"
namespace aws.protocoltests.restjson
use aws.protocols#restJson1
use aws.api#service
use smithy.test#httpRequestTests
/// A REST JSON service that sends JSON requests and responses.
@service(sdkId: "Rest Json Protocol")
@restJson1
service RestJsonExtras {
version: "2019-12-16",
operations: [EnumPayload, StringPayload]
}
@http(uri: "/EnumPayload", method: "POST")
@httpRequestTests([
{
id: "EnumPayload",
uri: "/EnumPayload",
body: "enumvalue",
params: { payload: "enumvalue" },
method: "POST",
protocol: "aws.protocols#restJson1"
}
])
operation EnumPayload {
input: EnumPayloadInput
}
structure EnumPayloadInput {
@httpPayload
payload: StringEnum
}
@enum([{"value": "enumvalue", "name": "V"}])
string StringEnum
@http(uri: "/StringPayload", method: "POST")
@httpRequestTests([
{
id: "StringPayload",
uri: "/StringPayload",
body: "rawstring",
params: { payload: "rawstring" },
method: "POST",
protocol: "aws.protocols#restJson1"
}
])
operation StringPayload {
input: StringPayloadInput
}
structure StringPayloadInput {
@httpPayload
payload: String
}

View File

@ -40,6 +40,8 @@ import software.amazon.smithy.model.traits.HttpLabelTrait
import software.amazon.smithy.rust.codegen.rustlang.RustType
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.Writable
import software.amazon.smithy.rust.codegen.smithy.traits.InputBodyTrait
import software.amazon.smithy.rust.codegen.smithy.traits.OutputBodyTrait
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticOutputTrait
import software.amazon.smithy.rust.codegen.util.toSnakeCase
@ -231,6 +233,7 @@ class SymbolVisitor(
val isError = shape.hasTrait(ErrorTrait::class.java)
val isInput = shape.hasTrait(SyntheticInputTrait::class.java)
val isOutput = shape.hasTrait(SyntheticOutputTrait::class.java)
val isBody = shape.hasTrait(InputBodyTrait::class.java) || shape.hasTrait(OutputBodyTrait::class.java)
val name = StringUtils.capitalize(shape.id.name).letIf(isError && config.codegenConfig.renameExceptions) {
// TODO: Do we want to do this?
// https://github.com/awslabs/smithy-rs/issues/77
@ -241,6 +244,7 @@ class SymbolVisitor(
isError -> builder.locatedIn(Errors)
isInput -> builder.locatedIn(Inputs)
isOutput -> builder.locatedIn(Outputs)
isBody -> builder.locatedIn(Serializers)
else -> builder.locatedIn(Models)
}.build()
}

View File

@ -49,13 +49,24 @@ abstract class HttpProtocolGenerator(
) {
private val symbolProvider = protocolConfig.symbolProvider
private val model = protocolConfig.model
fun renderOperation(operationWriter: RustWriter, inputWriter: RustWriter, operationShape: OperationShape, customizations: List<OperationCustomization>) {
fun renderOperation(
operationWriter: RustWriter,
inputWriter: RustWriter,
operationShape: OperationShape,
customizations: List<OperationCustomization>
) {
/* if (operationShape.hasTrait(EndpointTrait::class.java)) {
TODO("https://github.com/awslabs/smithy-rs/issues/197")
} */
val inputShape = operationShape.inputShape(model)
val inputSymbol = symbolProvider.toSymbol(inputShape)
val builderGenerator = OperationInputBuilderGenerator(model, symbolProvider, operationShape, protocolConfig.moduleName, customizations)
val builderGenerator = OperationInputBuilderGenerator(
model,
symbolProvider,
operationShape,
protocolConfig.moduleName,
customizations
)
builderGenerator.render(inputWriter)
// impl OperationInputShape { ... }
@ -63,7 +74,7 @@ abstract class HttpProtocolGenerator(
toHttpRequestImpl(this, operationShape, inputShape)
val shapeId = inputShape.expectTrait(SyntheticInputTrait::class.java).body
val body = shapeId?.let { model.expectShape(it, StructureShape::class.java) }
toBodyImpl(this, inputShape, body)
toBodyImpl(this, inputShape, body, operationShape)
// TODO: streaming shapes need special support
rustBlock(
"pub fn assemble(builder: #1T, body: #3T) -> #2T<#3T>",
@ -130,14 +141,18 @@ abstract class HttpProtocolGenerator(
}
}
protected fun fromResponseFun(implBlockWriter: RustWriter, operationShape: OperationShape, f: RustWriter.() -> Unit) {
protected fun fromResponseFun(
implBlockWriter: RustWriter,
operationShape: OperationShape,
block: RustWriter.() -> Unit
) {
implBlockWriter.rustBlock(
"fn from_response(response: &#T<impl AsRef<[u8]>>) -> Result<#T, #T>",
RuntimeType.Http("response::Response"),
symbolProvider.toSymbol(operationShape.outputShape(model)),
operationShape.errorSymbol(symbolProvider)
) {
f(this)
block(this)
}
}
@ -150,12 +165,21 @@ abstract class HttpProtocolGenerator(
*
* Your implementation MUST call [bodyBuilderFun] to create the public method.
*/
abstract fun toBodyImpl(implBlockWriter: RustWriter, inputShape: StructureShape, inputBody: StructureShape?)
abstract fun toBodyImpl(
implBlockWriter: RustWriter,
inputShape: StructureShape,
inputBody: StructureShape?,
operationShape: OperationShape
)
/**
* Add necessary methods to the impl block for the input shape.
*
* Your implementation MUST call [httpBuilderFun] to create the public method.
*/
abstract fun toHttpRequestImpl(implBlockWriter: RustWriter, operationShape: OperationShape, inputShape: StructureShape)
abstract fun toHttpRequestImpl(
implBlockWriter: RustWriter,
operationShape: OperationShape,
inputShape: StructureShape
)
}

View File

@ -172,7 +172,7 @@ class HttpProtocolTestGenerator(
checkRequiredHeaders(this, httpRequestTestCase.requireHeaders)
if (protocolSupport.requestBodySerialization) {
// "If no request body is defined, then no assertions are made about the body of the message."
httpRequestTestCase.body.orNull()?.let { body ->
httpRequestTestCase.body.orNull()?.also { body ->
checkBody(this, body, httpRequestTestCase.bodyMediaType.orNull())
}
}
@ -253,7 +253,7 @@ class HttpProtocolTestGenerator(
rustWriter.write("""let body = http_request.body().bytes().expect("body should be strict");""")
if (body == "") {
rustWriter.write("// No body")
rustWriter.write("assert!(&body.is_empty());")
rustWriter.write("assert_eq!(std::str::from_utf8(body).unwrap(), ${"".dq()});")
} else {
// When we generate a body instead of a stub, drop the trailing `;` and enable the assertion
assertOk(rustWriter) {
@ -383,7 +383,9 @@ class HttpProtocolTestGenerator(
// or because they are flaky
private val DisableTests = setOf(
// This test is flaky because of set ordering serialization https://github.com/awslabs/smithy-rs/issues/37
"AwsJson11Enums"
"AwsJson11Enums",
"RestJsonJsonEnums",
"RestJsonLists"
)
}
}

View File

@ -23,6 +23,7 @@ import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectMember
fun HttpTrait.uriFormatString(): String {
val base = uri.segments.joinToString("/", prefix = "/") {
@ -160,7 +161,7 @@ class HttpTraitBindingGenerator(
private fun uriBase(writer: RustWriter) {
val formatString = httpTrait.uriFormatString()
val args = httpTrait.uri.labels.map { label ->
val member = inputShape.getMember(label.content).get()
val member = inputShape.expectMember(label.content)
"${label.content} = ${labelFmtFun(model.expectShape(member.target), member, label)}"
}
val combinedArgs = listOf(formatString, *args.toTypedArray())

View File

@ -41,6 +41,7 @@ import software.amazon.smithy.rust.codegen.smithy.isOptional
import software.amazon.smithy.rust.codegen.smithy.rustType
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectMember
import software.amazon.smithy.rust.codegen.util.toPascalCase
/**
@ -187,7 +188,7 @@ class Instantiator(
check(data.members.size == 1)
val variant = data.members.iterator().next()
val memberName = variant.key.value
val member = shape.getMember(memberName).get()
val member = shape.expectMember(memberName)
.let { model.expectShape(it.target) }
// TODO: refactor this detail into UnionGenerator
writer.write("#T::${memberName.toPascalCase()}", unionSymbol)
@ -278,8 +279,7 @@ class Instantiator(
}
private fun getMember(shape: StructureShape, key: StringNode): Pair<MemberShape, Shape> {
val memberShape = shape.getMember(key.value)
.orElseThrow { IllegalArgumentException("$shape did not have member ${key.value}") }
val memberShape = shape.expectMember(key.value)
return memberShape to model.expectShape(memberShape.target)
}
}

View File

@ -56,7 +56,7 @@ class BasicAwsJsonFactory(private val version: AwsJsonVersion) : ProtocolGenerat
protocolConfig: ProtocolConfig
): BasicAwsJsonGenerator = BasicAwsJsonGenerator(protocolConfig, version)
private val shapeIfHasMembers: StructureModifier = { shape: StructureShape? ->
private val shapeIfHasMembers: StructureModifier = { _, shape: StructureShape? ->
if (shape?.members().isNullOrEmpty()) {
null
} else {
@ -178,7 +178,12 @@ class BasicAwsJsonGenerator(
}
}
override fun toBodyImpl(implBlockWriter: RustWriter, inputShape: StructureShape, inputBody: StructureShape?) {
override fun toBodyImpl(
implBlockWriter: RustWriter,
inputShape: StructureShape,
inputBody: StructureShape?,
operationShape: OperationShape
) {
if (inputBody == null) {
bodyBuilderFun(implBlockWriter) {
write("\"{}\".to_string().into()")

View File

@ -6,30 +6,64 @@
package software.amazon.smithy.rust.codegen.smithy.protocols
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.HttpBinding
import software.amazon.smithy.model.knowledge.HttpBindingIndex
import software.amazon.smithy.model.shapes.BlobShape
import software.amazon.smithy.model.shapes.DocumentShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.HttpTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.rustlang.writable
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.generators.HttpProtocolGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.HttpTraitBindingGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolGeneratorFactory
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolSupport
import software.amazon.smithy.rust.codegen.smithy.isOptional
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectMember
class AwsRestJsonFactory : ProtocolGeneratorFactory<AwsRestJsonGenerator> {
override fun buildProtocolGenerator(
protocolConfig: ProtocolConfig
): AwsRestJsonGenerator = AwsRestJsonGenerator(protocolConfig)
/** Create a synthetic awsJsonInputBody if specified
* A body is created iff no member of [input] is targeted with the `PAYLOAD` trait. If a member is targeted with
* the payload trait, we don't need to create an input body.
*/
private fun awsJsonInputBody(model: Model, operation: OperationShape, input: StructureShape?): StructureShape? {
if (input == null) {
return null
}
val bindingIndex = HttpBindingIndex.of(model)
val bindings: MutableMap<String, HttpBinding> = bindingIndex.getRequestBindings(operation)
val bodyMembers = input.members().filter { member ->
bindings[member.memberName]?.location == HttpBinding.Location.DOCUMENT
}
return if (bodyMembers.isNotEmpty()) {
input.toBuilder().members(bodyMembers).build()
} else {
null
}
}
override fun transformModel(model: Model): Model {
// TODO: AWSRestJson determines the body from HTTP traits
return OperationNormalizer(model).transformModel(
inputBodyFactory = OperationNormalizer.NoBody,
inputBodyFactory = { op, input -> awsJsonInputBody(model, op, input) },
outputBodyFactory = OperationNormalizer.NoBody
)
}
@ -37,15 +71,23 @@ class AwsRestJsonFactory : ProtocolGeneratorFactory<AwsRestJsonGenerator> {
override fun support(): ProtocolSupport {
// TODO: Support body for RestJson
return ProtocolSupport(
requestBodySerialization = false,
requestBodySerialization = true,
responseDeserialization = false,
errorDeserialization = false
)
}
override fun symbolProvider(model: Model, base: RustSymbolProvider): RustSymbolProvider {
return JsonSerializerSymbolProvider(
model,
SyntheticBodySymbolProvider(model, base),
TimestampFormatTrait.Format.EPOCH_SECONDS
)
}
}
class AwsRestJsonGenerator(
protocolConfig: ProtocolConfig
private val protocolConfig: ProtocolConfig
) : HttpProtocolGenerator(protocolConfig) {
// restJson1 requires all operations to use the HTTP trait
@ -66,9 +108,94 @@ class AwsRestJsonGenerator(
}
}
override fun toBodyImpl(implBlockWriter: RustWriter, inputShape: StructureShape, inputBody: StructureShape?) {
private fun serializeViaSyntheticBody(
implBlockWriter: RustWriter,
inputBody: StructureShape
) {
val bodySymbol = protocolConfig.symbolProvider.toSymbol(inputBody)
implBlockWriter.rustBlock("fn body(&self) -> #T", bodySymbol) {
rustBlock("#T", bodySymbol) {
for (member in inputBody.members()) {
val name = protocolConfig.symbolProvider.toMemberName(member)
write("$name: &self.$name,")
}
}
}
bodyBuilderFun(implBlockWriter) {
rust(""""body not generated yet".into()""")
write("""#T(&self.body()).expect("serialization should succeed")""", RuntimeType.SerdeJson("to_vec"))
}
}
override fun toBodyImpl(
implBlockWriter: RustWriter,
inputShape: StructureShape,
inputBody: StructureShape?,
operationShape: OperationShape
) {
// If we created a synthetic input body, serialize that
if (inputBody != null) {
return serializeViaSyntheticBody(implBlockWriter, inputBody)
}
// Otherwise, we need to serialize via the HTTP payload trait
val bindings = httpIndex.getRequestBindings(operationShape).toList()
val payload: Pair<String, HttpBinding>? =
bindings.firstOrNull { (_, binding) -> binding.location == HttpBinding.Location.PAYLOAD }
val payloadSerde = payload?.let { (payloadMemberName, _) ->
val member = inputShape.expectMember(payloadMemberName)
val rustMemberName = "self.${symbolProvider.toMemberName(member)}"
val targetShape = model.expectShape(member.target)
writable {
val payloadName = safeName()
rust("let $payloadName = &$rustMemberName;")
// If this targets a member & the member is None, return an empty vec
if (symbolProvider.toSymbol(member).isOptional()) {
rust(
"""
let $payloadName = match $payloadName.as_ref() {
Some(t) => t,
None => return vec![]
};"""
)
}
renderPayload(targetShape, payloadName)
}
// body is null, no payload set, so this is empty
} ?: writable { rust("vec![]") }
bodyBuilderFun(implBlockWriter) {
payloadSerde(this)
}
}
private fun RustWriter.renderPayload(
targetShape: Shape,
payloadName: String,
) {
val serdeToVec = RuntimeType.SerdeJson("to_vec")
when (targetShape) {
// Write the raw string to the payload
is StringShape ->
if (targetShape.hasTrait(EnumTrait::class.java)) {
rust("$payloadName.as_str().into()")
} else {
rust("""$payloadName.to_string().into()""")
}
is BlobShape ->
// Write the raw blob to the payload
rust("$payloadName.as_ref().into()")
is StructureShape, is UnionShape ->
// JSON serialize the structure or union targetted
rust(
"""#T(&$payloadName).expect("serialization should succeed")""",
serdeToVec
)
is DocumentShape ->
rustTemplate(
"""#{to_vec}(&#{doc_json}::SerDoc(&$payloadName)).expect("serialization should succeed")""",
"to_vec" to serdeToVec,
"doc_json" to RuntimeType.DocJson
)
else -> TODO("Unexpected payload target type")
}
}

View File

@ -19,7 +19,7 @@ import software.amazon.smithy.rust.codegen.util.orNull
import java.util.Optional
import kotlin.streams.toList
typealias StructureModifier = (StructureShape?) -> StructureShape?
typealias StructureModifier = (OperationShape, StructureShape?) -> StructureShape?
/**
* Generate synthetic Input and Output structures for operations.
@ -64,6 +64,7 @@ class OperationNormalizer(private val model: Model) {
): List<StructureShape> {
val outputId = operation.outputId()
val outputBodyShape = outputBodyFactory(
operation,
operation.output.map { model.expectShape(it, StructureShape::class.java) }.orNull()
)?.let { it.toBuilder().addTrait(OutputBodyTrait()).rename(operation.outputBodyId()).build() }
val outputShapeBuilder = operation.output.map { shapeId ->
@ -79,6 +80,7 @@ class OperationNormalizer(private val model: Model) {
): List<StructureShape> {
val inputId = operation.inputId()
val inputBodyShape = inputBodyFactory(
operation,
operation.input.map {
val inputShape = model.expectShape(it, StructureShape::class.java)
inputShape.toBuilder().addTrait(InputBodyTrait()).rename(operation.inputBodyId()).build()
@ -101,7 +103,7 @@ class OperationNormalizer(private val model: Model) {
private fun OperationShape.inputBodyId() = ShapeId.fromParts(this.id.namespace, "${this.id.name}InputBody")
private fun OperationShape.outputBodyId() = ShapeId.fromParts(this.id.namespace, "${this.id.name}OutputBody")
val NoBody: (StructureShape?) -> StructureShape? = { _ -> null }
val NoBody: StructureModifier = { _: OperationShape, _: StructureShape? -> null }
}
}

View File

@ -5,11 +5,14 @@
package software.amazon.smithy.rust.codegen.util
import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
inline fun <reified T : Shape> Model.lookup(shapeId: String): T {
return this.expectShape(ShapeId.from(shapeId), T::class.java)
@ -24,3 +27,9 @@ fun OperationShape.outputShape(model: Model): StructureShape {
// The Rust Smithy generator adds an output to all shapes automatically
return model.expectShape(this.output.get(), StructureShape::class.java)
}
fun StructureShape.expectMember(member: String): MemberShape =
this.getMember(member).orElseThrow { CodegenException("$member did not exist on $this") }
fun UnionShape.expectMember(member: String): MemberShape =
this.getMember(member).orElseThrow { CodegenException("$member did not exist on $this") }

View File

@ -129,7 +129,8 @@ class HttpProtocolTestGeneratorTest {
override fun toBodyImpl(
implBlockWriter: RustWriter,
inputShape: StructureShape,
inputBody: StructureShape?
inputBody: StructureShape?,
operationShape: OperationShape
) {
bodyBuilderFun(implBlockWriter) {
writeWithNoFormatting(body)

View File

@ -112,10 +112,10 @@ internal class OperationNormalizerTest {
val sut = OperationNormalizer(model)
val modified = sut.transformModel(
inputBodyFactory = { input ->
inputBodyFactory = { _, input ->
input?.toBuilder()?.members(input.members().filter { it.memberName != "drop" })?.build()
},
outputBodyFactory = { it?.toBuilder()?.members(emptyList())?.build() }
outputBodyFactory = { _, output -> output?.toBuilder()?.members(emptyList())?.build() }
)
val operation = modified.lookup<OperationShape>("smithy.test#MyOp")
operation.input.isPresent shouldBe true

View File

@ -251,7 +251,7 @@ fn validate_json_body(actual: &str, expected: &str) -> Result<(), ProtocolTestFa
let actual_json: serde_json::Value =
serde_json::from_str(actual).map_err(|e| ProtocolTestFailure::InvalidBodyFormat {
expected: "json".to_owned(),
found: e.to_string(),
found: e.to_string() + actual,
})?;
let expected_json: serde_json::Value =
serde_json::from_str(expected).expect("expected value must be valid JSON");