Test `httpRequestTests` against actual Services (#1708)

* Make Instantiator generate default values for required field on demand

* Move looping over operations into ServerProtocolTestGenerator

Signed-off-by: Weihang Lo <weihanglo@users.noreply.github.com>

* Add protocol test helper functions

Signed-off-by: Weihang Lo <weihanglo@users.noreply.github.com>

* Add method param to construct http request

* Put request validation logic inside closure

Signed-off-by: Weihang Lo <weihanglo@users.noreply.github.com>

* Make protocol test response instantiate with default values

* Add module meta for helper module

Signed-off-by: Weihang Lo <weihanglo@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: david-perez <d@vidp.dev>

* Address most style suggestions

* add companion object for attribute #[allow(dead_code)]

Signed-off-by: Weihang Lo <weihanglo@users.noreply.github.com>

* Use writable to make code readable

* recursively call `filldefaultValue`

Signed-off-by: Weihang Lo <weihanglo@users.noreply.github.com>

* Exercise with `OperationExtension`

* Temporary protocol tests fix for awslabs/smithy#1391

Missing `X-Amz-Target` in response header

* Add `X-Amz-Target` for common models

Signed-off-by: Weihang Lo <weihanglo@users.noreply.github.com>
Co-authored-by: david-perez <d@vidp.dev>
Co-authored-by: Harry Barber <hlbarber@amazon.co.uk>
This commit is contained in:
Weihang Lo 2022-09-09 21:07:21 +01:00 committed by GitHub
parent e009f3f47f
commit f27aa54650
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 398 additions and 83 deletions

View File

@ -384,6 +384,7 @@ sealed class Attribute {
*/
val NonExhaustive = Custom("non_exhaustive")
val AllowUnusedMut = Custom("allow(unused_mut)")
val AllowDeadCode = Custom("allow(dead_code)")
val DocHidden = Custom("doc(hidden)")
val DocInline = Custom("doc(inline)")
}

View File

@ -5,6 +5,7 @@
package software.amazon.smithy.rust.codegen.client.smithy.generators
import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.node.ArrayNode
import software.amazon.smithy.model.node.Node
@ -68,13 +69,19 @@ class Instantiator(
val streaming: Boolean,
// Whether we are instantiating with a Builder, in which case all setters take Option
val builder: Boolean,
// Fill out `required` fields with a default value.
val defaultsForRequiredFields: Boolean,
)
companion object {
fun defaultContext() = Ctx(lowercaseMapKeys = false, streaming = false, builder = false, defaultsForRequiredFields = false)
}
fun render(
writer: RustWriter,
shape: Shape,
arg: Node,
ctx: Ctx = Ctx(lowercaseMapKeys = false, streaming = false, builder = false),
ctx: Ctx = defaultContext(),
) {
when (shape) {
// Compound Shapes
@ -222,14 +229,23 @@ class Instantiator(
*/
private fun renderUnion(writer: RustWriter, shape: UnionShape, data: ObjectNode, ctx: Ctx) {
val unionSymbol = symbolProvider.toSymbol(shape)
check(data.members.size == 1)
val variant = data.members.iterator().next()
val memberName = variant.key.value
val variant = if (ctx.defaultsForRequiredFields && data.members.isEmpty()) {
val (name, memberShape) = shape.allMembers.entries.first()
val targetShape = model.expectShape(memberShape.target)
Node.from(name) to fillDefaultValue(targetShape)
} else {
check(data.members.size == 1)
val entry = data.members.iterator().next()
entry.key to entry.value
}
val memberName = variant.first.value
val member = shape.expectMember(memberName)
writer.write("#T::${symbolProvider.toMemberName(member)}", unionSymbol)
// unions should specify exactly one member
writer.withBlock("(", ")") {
renderMember(this, member, variant.value, ctx)
renderMember(this, member, variant.second, ctx)
}
}
@ -267,16 +283,54 @@ class Instantiator(
* ```
*/
private fun renderStructure(writer: RustWriter, shape: StructureShape, data: ObjectNode, ctx: Ctx) {
writer.write("#T::builder()", symbolProvider.toSymbol(shape))
data.members.forEach { (key, value) ->
val memberShape = shape.expectMember(key.value)
fun renderMemberHelper(memberShape: MemberShape, value: Node) {
writer.withBlock(".${memberShape.setterName()}(", ")") {
renderMember(this, memberShape, value, ctx)
}
}
writer.write("#T::builder()", symbolProvider.toSymbol(shape))
if (ctx.defaultsForRequiredFields) {
shape.allMembers.entries
.filter { (name, memberShape) ->
memberShape.isRequired && !data.members.containsKey(Node.from(name))
}
.forEach { (_, memberShape) ->
renderMemberHelper(memberShape, fillDefaultValue(memberShape))
}
}
data.members.forEach { (key, value) ->
val memberShape = shape.expectMember(key.value)
renderMemberHelper(memberShape, value)
}
writer.write(".build()")
if (StructureGenerator.fallibleBuilder(shape, symbolProvider)) {
writer.write(".unwrap()")
}
}
/**
* Returns a default value for a shape.
*
* Warning: this method does not take into account any constraint traits attached to the shape.
*/
private fun fillDefaultValue(shape: Shape): Node = when (shape) {
is MemberShape -> fillDefaultValue(model.expectShape(shape.target))
// Aggregate shapes.
is StructureShape -> Node.objectNode()
is UnionShape -> Node.objectNode()
is CollectionShape -> Node.arrayNode()
is MapShape -> Node.objectNode()
// Simple Shapes
is TimestampShape -> Node.from(0) // Number node for timestamp
is BlobShape -> Node.from("") // String node for bytes
is StringShape -> Node.from("")
is NumberShape -> Node.from(0)
is BooleanShape -> Node.from(false)
is DocumentShape -> Node.objectNode()
else -> throw CodegenException("Unrecognized shape `$shape`")
}
}

View File

@ -66,6 +66,41 @@ class InstantiatorTest {
member: WithBox,
value: Integer
}
structure MyStructRequired {
@required
str: String,
@required
primitiveInt: PrimitiveInteger,
@required
int: Integer,
@required
ts: Timestamp,
@required
byte: Byte
@required
union: NestedUnion,
@required
structure: NestedStruct,
@required
list: MyList,
@required
map: NestedMap,
@required
doc: Document
}
union NestedUnion {
struct: NestedStruct,
int: Integer
}
structure NestedStruct {
@required
str: String,
@required
num: Integer
}
""".asSmithyModel().let { RecursiveShapeBoxer.transform(it) }
private val symbolProvider = testSymbolProvider(model)
@ -236,4 +271,51 @@ class InstantiatorTest {
}
writer.compileAndTest()
}
@Test
fun `generate struct with missing required members`() {
val structure = model.lookup<StructureShape>("com.test#MyStructRequired")
val inner = model.lookup<StructureShape>("com.test#Inner")
val nestedStruct = model.lookup<StructureShape>("com.test#NestedStruct")
val union = model.lookup<UnionShape>("com.test#NestedUnion")
val sut = Instantiator(symbolProvider, model, runtimeConfig, CodegenTarget.SERVER)
val data = Node.parse("{}")
val writer = RustWriter.forModule("model")
structure.renderWithModelBuilder(model, symbolProvider, writer)
inner.renderWithModelBuilder(model, symbolProvider, writer)
nestedStruct.renderWithModelBuilder(model, symbolProvider, writer)
UnionGenerator(model, symbolProvider, writer, union).render()
writer.test {
writer.withBlock("let result = ", ";") {
sut.render(this, structure, data, Instantiator.defaultContext().copy(defaultsForRequiredFields = true))
}
writer.write(
"""
use std::collections::HashMap;
use aws_smithy_types::{DateTime, Document};
let expected = MyStructRequired {
str: Some("".into()),
primitive_int: 0,
int: Some(0),
ts: Some(DateTime::from_secs(0)),
byte: Some(0),
union: Some(NestedUnion::Struct(NestedStruct {
str: Some("".into()),
num: Some(0),
})),
structure: Some(NestedStruct {
str: Some("".into()),
num: Some(0),
}),
list: Some(vec![]),
map: Some(HashMap::new()),
doc: Some(Document::Object(HashMap::new())),
};
assert_eq!(result, expected);
""",
)
}
writer.compileAndTest()
}
}

View File

@ -34,7 +34,10 @@ service Config {
uri: "/",
body: "{\"as\": 5, \"async\": true}",
bodyMediaType: "application/json",
headers: {"Content-Type": "application/x-amz-json-1.1"}
headers: {
"Content-Type": "application/x-amz-json-1.1",
"X-Amz-Target": "Config.ReservedWordsAsMembers",
},
}
])
operation ReservedWordsAsMembers {
@ -78,7 +81,10 @@ structure Type {
uri: "/",
body: "{\"regular_string\": \"hello!\"}",
bodyMediaType: "application/json",
headers: {"Content-Type": "application/x-amz-json-1.1"}
headers: {
"Content-Type": "application/x-amz-json-1.1",
"X-Amz-Target": "Config.StructureNamePunning",
},
}
])
operation StructureNamePunning {

View File

@ -10,6 +10,7 @@ import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.rust.codegen.client.rustlang.RustModule
import software.amazon.smithy.rust.codegen.client.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.client.smithy.CoreCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.DefaultPublicModules
import software.amazon.smithy.rust.codegen.client.smithy.RustCrate
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolSupport
@ -37,15 +38,11 @@ open class ServerServiceGenerator(
* which assigns a symbol location to each shape.
*/
fun render() {
rustCrate.withModule(DefaultPublicModules["operation"]!!) { writer ->
ServerProtocolTestGenerator(coreCodegenContext, protocolSupport, protocolGenerator).render(writer)
}
for (operation in operations) {
rustCrate.useShapeWriter(operation) { operationWriter ->
protocolGenerator.serverRenderOperation(
operationWriter,
operation,
)
ServerProtocolTestGenerator(coreCodegenContext, protocolSupport, operation, operationWriter)
.render()
}
if (operation.errors.isNotEmpty()) {
rustCrate.withModule(RustModule.Error) { writer ->
renderCombinedErrors(writer, operation)

View File

@ -5,7 +5,9 @@
package software.amazon.smithy.rust.codegen.server.smithy.generators.protocol
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.knowledge.OperationIndex
import software.amazon.smithy.model.knowledge.TopDownIndex
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.DoubleShape
import software.amazon.smithy.model.shapes.FloatShape
@ -24,6 +26,7 @@ import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait
import software.amazon.smithy.rust.codegen.client.rustlang.Attribute
import software.amazon.smithy.rust.codegen.client.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.client.rustlang.RustMetadata
import software.amazon.smithy.rust.codegen.client.rustlang.RustReservedWords
import software.amazon.smithy.rust.codegen.client.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.client.rustlang.Visibility
import software.amazon.smithy.rust.codegen.client.rustlang.asType
@ -32,10 +35,12 @@ import software.amazon.smithy.rust.codegen.client.rustlang.rust
import software.amazon.smithy.rust.codegen.client.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.client.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.client.rustlang.withBlock
import software.amazon.smithy.rust.codegen.client.rustlang.writable
import software.amazon.smithy.rust.codegen.client.smithy.CoreCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.client.smithy.generators.CodegenTarget
import software.amazon.smithy.rust.codegen.client.smithy.generators.Instantiator
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.client.smithy.transformers.allErrors
import software.amazon.smithy.rust.codegen.client.testutil.TokioTest
@ -49,29 +54,46 @@ import software.amazon.smithy.rust.codegen.client.util.orNull
import software.amazon.smithy.rust.codegen.client.util.outputShape
import software.amazon.smithy.rust.codegen.client.util.toSnakeCase
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpBoundProtocolGenerator
import java.util.logging.Logger
import kotlin.reflect.KFunction1
private const val PROTOCOL_TEST_HELPER_MODULE_NAME = "protocol_test_helper"
/**
* Generate protocol tests for an operation
*/
class ServerProtocolTestGenerator(
private val coreCodegenContext: CoreCodegenContext,
private val protocolSupport: ProtocolSupport,
private val operationShape: OperationShape,
private val writer: RustWriter,
private val protocolGenerator: ProtocolGenerator,
) {
private val logger = Logger.getLogger(javaClass.name)
private val model = coreCodegenContext.model
private val inputShape = operationShape.inputShape(coreCodegenContext.model)
private val outputShape = operationShape.outputShape(coreCodegenContext.model)
private val symbolProvider = coreCodegenContext.symbolProvider
private val operationSymbol = symbolProvider.toSymbol(operationShape)
private val operationIndex = OperationIndex.of(coreCodegenContext.model)
private val operationImplementationName = "${operationSymbol.name}${ServerHttpBoundProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}"
private val operationErrorName = "crate::error::${operationSymbol.name}Error"
private val operations = TopDownIndex.of(coreCodegenContext.model).getContainedOperations(coreCodegenContext.serviceShape).sortedBy { it.id }
private val operationInputOutputTypes = operations.associateWith {
val inputSymbol = symbolProvider.toSymbol(it.inputShape(model))
val outputSymbol = symbolProvider.toSymbol(it.outputShape(model))
val operationSymbol = symbolProvider.toSymbol(it)
val inputT = inputSymbol.fullName
val t = outputSymbol.fullName
val outputT = if (it.errors.isEmpty()) {
t
} else {
val errorType = RuntimeType("${operationSymbol.name}Error", null, "crate::error")
val e = errorType.fullyQualifiedName()
"Result<$t, $e>"
}
inputT to outputT
}
private val instantiator = with(coreCodegenContext) {
Instantiator(symbolProvider, model, runtimeConfig, CodegenTarget.SERVER)
@ -82,8 +104,10 @@ class ServerProtocolTestGenerator(
"SmithyHttp" to CargoDependency.SmithyHttp(coreCodegenContext.runtimeConfig).asType(),
"Http" to CargoDependency.Http.asType(),
"Hyper" to CargoDependency.Hyper.asType(),
"Tower" to CargoDependency.Tower.asType(),
"SmithyHttpServer" to ServerCargoDependency.SmithyHttpServer(coreCodegenContext.runtimeConfig).asType(),
"AssertEq" to CargoDependency.PrettyAssertions.asType().member("assert_eq!"),
"Router" to ServerRuntimeType.Router(coreCodegenContext.runtimeConfig),
)
sealed class TestCase {
@ -92,7 +116,7 @@ class ServerProtocolTestGenerator(
abstract val protocol: ShapeId
abstract val testType: TestType
data class RequestTest(val testCase: HttpRequestTestCase) : TestCase() {
data class RequestTest(val testCase: HttpRequestTestCase, val operationShape: OperationShape) : TestCase() {
override val id: String = testCase.id
override val documentation: String? = testCase.documentation.orNull()
override val protocol: ShapeId = testCase.protocol
@ -114,9 +138,96 @@ class ServerProtocolTestGenerator(
}
}
fun render() {
fun render(writer: RustWriter) {
renderTestHelper(writer)
for (operation in operations) {
protocolGenerator.serverRenderOperation(writer, operation)
renderOperationTestCases(operation, writer)
}
}
/**
* Render a test helper module to:
*
* - generate a dynamic builder for each handler, and
* - construct a Tower service to exercise each test case.
*/
private fun renderTestHelper(writer: RustWriter) {
val operationNames = operations.map { it.toName() }
val operationRegistryName = "OperationRegistry"
val operationRegistryBuilderName = "${operationRegistryName}Builder"
fun renderRegistryBuilderTypeParams() = writable {
operations.forEach {
val (inputT, outputT) = operationInputOutputTypes[it]!!
writeInline("Fun<$inputT, $outputT>, (), ")
}
}
fun renderRegistryBuilderMethods() = writable {
operations.withIndex().forEach {
val (inputT, outputT) = operationInputOutputTypes[it.value]!!
val operationName = operationNames[it.index]
write(".$operationName((|_| Box::pin(async { todo!() })) as Fun<$inputT, $outputT> )")
}
}
val moduleMeta = RustMetadata(
additionalAttributes = listOf(
Attribute.Cfg("test"),
Attribute.AllowDeadCode,
),
visibility = Visibility.PUBCRATE,
)
writer.withModule(PROTOCOL_TEST_HELPER_MODULE_NAME, moduleMeta) {
rustTemplate(
"""
use #{Tower}::Service as _;
pub(crate) type Fun<Input, Output> = fn(Input) -> std::pin::Pin<Box<dyn std::future::Future<Output = Output> + Send>>;
type RegistryBuilder = crate::operation_registry::$operationRegistryBuilderName<#{Hyper}::Body, #{RegistryBuilderTypeParams:W}>;
fn create_operation_registry_builder() -> RegistryBuilder {
crate::operation_registry::$operationRegistryBuilderName::default()
#{RegistryBuilderMethods:W}
}
/// The operation full name is a concatenation of `<operation namespace>.<operation name>`.
pub(crate) async fn build_router_and_make_request(
http_request: #{Http}::request::Request<#{SmithyHttpServer}::body::Body>,
operation_full_name: &str,
f: &dyn Fn(RegistryBuilder) -> RegistryBuilder,
) {
let mut router: #{Router} = f(create_operation_registry_builder())
.build()
.expect("unable to build operation registry")
.into();
let http_response = router
.call(http_request)
.await
.expect("unable to make an HTTP request");
let operation_extension = http_response.extensions()
.get::<#{SmithyHttpServer}::extension::OperationExtension>()
.expect("extension `OperationExtension` not found");
#{AssertEq}(operation_extension.absolute(), operation_full_name);
}
""",
"RegistryBuilderTypeParams" to renderRegistryBuilderTypeParams(),
"RegistryBuilderMethods" to renderRegistryBuilderMethods(),
*codegenScope,
)
}
}
private fun renderOperationTestCases(operationShape: OperationShape, writer: RustWriter) {
val outputShape = operationShape.outputShape(coreCodegenContext.model)
val operationSymbol = symbolProvider.toSymbol(operationShape)
val requestTests = operationShape.getTrait<HttpRequestTestsTrait>()
?.getTestCasesFor(AppliesTo.SERVER).orEmpty().map { TestCase.RequestTest(it) }
?.getTestCasesFor(AppliesTo.SERVER).orEmpty().map { TestCase.RequestTest(it, operationShape) }
val responseTests = operationShape.getTrait<HttpResponseTestsTrait>()
?.getTestCasesFor(AppliesTo.SERVER).orEmpty().map { TestCase.ResponseTest(it, outputShape) }
val errorTests = operationIndex.getErrors(operationShape).flatMap { error ->
@ -141,23 +252,26 @@ class ServerProtocolTestGenerator(
visibility = Visibility.PRIVATE,
)
writer.withModule(testModuleName, moduleMeta) {
renderAllTestCases(allTests)
renderAllTestCases(operationShape, allTests)
}
}
}
private fun RustWriter.renderAllTestCases(allTests: List<TestCase>) {
private fun RustWriter.renderAllTestCases(operationShape: OperationShape, allTests: List<TestCase>) {
allTests.forEach {
val operationSymbol = symbolProvider.toSymbol(operationShape)
renderTestCaseBlock(it, this) {
when (it) {
is TestCase.RequestTest -> this.renderHttpRequestTestCase(it.testCase)
is TestCase.ResponseTest -> this.renderHttpResponseTestCase(it.testCase, it.targetShape)
is TestCase.MalformedRequestTest -> this.renderHttpMalformedRequestTestCase(it.testCase)
is TestCase.RequestTest -> this.renderHttpRequestTestCase(it.testCase, operationShape, operationSymbol)
is TestCase.ResponseTest -> this.renderHttpResponseTestCase(it.testCase, it.targetShape, operationShape, operationSymbol)
is TestCase.MalformedRequestTest -> this.renderHttpMalformedRequestTestCase(it.testCase, operationSymbol)
}
}
}
}
private fun OperationShape.toName(): String = RustReservedWords.escapeIfNeeded(symbolProvider.toSymbol(this).name.toSnakeCase())
/**
* Filter out test cases that are disabled or don't match the service protocol
*/
@ -182,8 +296,8 @@ class ServerProtocolTestGenerator(
if (howToFixIt == null) {
it
} else {
val fixed = howToFixIt(it.testCase)
TestCase.RequestTest(fixed)
val fixed = howToFixIt(it.testCase, it.operationShape)
TestCase.RequestTest(fixed, it.operationShape)
}
}
is TestCase.ResponseTest -> {
@ -236,16 +350,18 @@ class ServerProtocolTestGenerator(
*/
private fun RustWriter.renderHttpRequestTestCase(
httpRequestTestCase: HttpRequestTestCase,
operationShape: OperationShape,
operationSymbol: Symbol,
) {
if (!protocolSupport.requestDeserialization) {
rust("/* test case disabled for this protocol (not yet supported) */")
return
}
with(httpRequestTestCase) {
renderHttpRequest(uri, headers, body.orNull(), queryParams, host.orNull())
renderHttpRequest(uri, method, headers, body.orNull(), queryParams, host.orNull())
}
if (protocolSupport.requestBodyDeserialization) {
checkParams(httpRequestTestCase, this)
checkRequest(operationShape, operationSymbol, httpRequestTestCase, this)
}
// Explicitly warn if the test case defined parameters that we aren't doing anything with
@ -272,7 +388,12 @@ class ServerProtocolTestGenerator(
private fun RustWriter.renderHttpResponseTestCase(
testCase: HttpResponseTestCase,
shape: StructureShape,
operationShape: OperationShape,
operationSymbol: Symbol,
) {
val operationImplementationName = "${operationSymbol.name}${ServerHttpBoundProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}"
val operationErrorName = "crate::error::${operationSymbol.name}Error"
if (!protocolSupport.responseSerialization || (
!protocolSupport.errorSerialization && shape.hasTrait<ErrorTrait>()
)
@ -308,10 +429,10 @@ class ServerProtocolTestGenerator(
* We are given a request definition and a response definition, and we have to assert that the request is rejected
* with the given response.
*/
private fun RustWriter.renderHttpMalformedRequestTestCase(testCase: HttpMalformedRequestTestCase) {
private fun RustWriter.renderHttpMalformedRequestTestCase(testCase: HttpMalformedRequestTestCase, operationSymbol: Symbol) {
with(testCase.request) {
// TODO(https://github.com/awslabs/smithy/issues/1102): `uri` should probably not be an `Optional`.
renderHttpRequest(uri.get(), headers, body.orNull(), queryParams, host.orNull())
renderHttpRequest(uri.get(), method, headers, body.orNull(), queryParams, host.orNull())
}
val operationName = "${operationSymbol.name}${ServerHttpBoundProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}"
@ -328,6 +449,7 @@ class ServerProtocolTestGenerator(
private fun RustWriter.renderHttpRequest(
uri: String,
method: String,
headers: Map<String, String>,
body: String?,
queryParams: List<String>,
@ -338,6 +460,7 @@ class ServerProtocolTestGenerator(
##[allow(unused_mut)]
let mut http_request = http::Request::builder()
.uri("$uri")
.method("$method")
""",
*codegenScope,
)
@ -372,20 +495,44 @@ class ServerProtocolTestGenerator(
}
}
private fun checkParams(httpRequestTestCase: HttpRequestTestCase, rustWriter: RustWriter) {
rustWriter.writeInline("let expected = ")
instantiator.render(rustWriter, inputShape, httpRequestTestCase.params)
rustWriter.write(";")
private fun checkRequest(operationShape: OperationShape, operationSymbol: Symbol, httpRequestTestCase: HttpRequestTestCase, rustWriter: RustWriter) {
val inputShape = operationShape.inputShape(coreCodegenContext.model)
val outputShape = operationShape.outputShape(coreCodegenContext.model)
val operationName = "${operationSymbol.name}${ServerHttpBoundProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}"
rustWriter.rustTemplate(
val (inputT, outputT) = operationInputOutputTypes[operationShape]!!
rustWriter.withBlock(
"""
let mut http_request = #{SmithyHttpServer}::request::RequestParts::new(http_request);
let parsed = super::$operationName::from_request(&mut http_request).await.expect("failed to parse request").0;
super::$PROTOCOL_TEST_HELPER_MODULE_NAME::build_router_and_make_request(
http_request,
"${operationShape.id.namespace}.${operationSymbol.name}",
&|builder| {
builder.${operationShape.toName()}((|input| Box::pin(async move {
""",
*codegenScope,
)
"})) as super::$PROTOCOL_TEST_HELPER_MODULE_NAME::Fun<$inputT, $outputT>)}).await",
) {
// Construct expected request.
rustWriter.withBlock("let expected = ", ";") {
instantiator.render(this, inputShape, httpRequestTestCase.params)
}
checkRequestParams(inputShape, rustWriter)
// Construct a dummy response.
rustWriter.withBlock("let response = ", ";") {
instantiator.render(this, outputShape, Node.objectNode(), Instantiator.defaultContext().copy(defaultsForRequiredFields = true))
}
if (operationShape.errors.isEmpty()) {
rustWriter.write("response")
} else {
rustWriter.write("Ok(response)")
}
}
}
private fun checkRequestParams(inputShape: StructureShape, rustWriter: RustWriter) {
if (inputShape.hasStreamingMember(model)) {
// A streaming shape does not implement `PartialEq`, so we have to iterate over the input shape's members
// and handle the equality assertion separately.
@ -395,7 +542,7 @@ class ServerProtocolTestGenerator(
rustWriter.rustTemplate(
"""
#{AssertEq}(
parsed.$memberName.collect().await.unwrap().into_bytes(),
input.$memberName.collect().await.unwrap().into_bytes(),
expected.$memberName.collect().await.unwrap().into_bytes()
);
""",
@ -404,7 +551,7 @@ class ServerProtocolTestGenerator(
} else {
rustWriter.rustTemplate(
"""
#{AssertEq}(parsed.$memberName, expected.$memberName, "Unexpected value for `$memberName`");
#{AssertEq}(input.$memberName, expected.$memberName, "Unexpected value for `$memberName`");
""",
*codegenScope,
)
@ -423,19 +570,21 @@ class ServerProtocolTestGenerator(
when (coreCodegenContext.model.expectShape(member.target)) {
is DoubleShape, is FloatShape -> {
rustWriter.addUseImports(
RuntimeType.ProtocolTestHelper(coreCodegenContext.runtimeConfig, "FloatEquals").toSymbol(),
RuntimeType.ProtocolTestHelper(coreCodegenContext.runtimeConfig, "FloatEquals")
.toSymbol(),
)
rustWriter.rust(
"""
assert!(parsed.$memberName.float_equals(&expected.$memberName),
"Unexpected value for `$memberName` {:?} vs. {:?}", expected.$memberName, parsed.$memberName);
assert!(input.$memberName.float_equals(&expected.$memberName),
"Unexpected value for `$memberName` {:?} vs. {:?}", expected.$memberName, input.$memberName);
""",
)
}
else -> {
rustWriter.rustTemplate(
"""
#{AssertEq}(parsed.$memberName, expected.$memberName, "Unexpected value for `$memberName`");
#{AssertEq}(input.$memberName, expected.$memberName, "Unexpected value for `$memberName`");
""",
*codegenScope,
)
@ -443,7 +592,7 @@ class ServerProtocolTestGenerator(
}
}
} else {
rustWriter.rustTemplate("#{AssertEq}(parsed, expected);", *codegenScope)
rustWriter.rustTemplate("#{AssertEq}(input, expected);", *codegenScope)
}
}
}
@ -457,7 +606,8 @@ class ServerProtocolTestGenerator(
// We can't check that the `OperationExtension` is set in the response, because it is set in the implementation
// of the operation `Handler` trait, a code path that does not get exercised when we don't have a request to
// invoke it with (like in the case of an `httpResponseTest` test case).
// checkHttpOperationExtension(rustWriter)
// In https://github.com/awslabs/smithy-rs/pull/1708: We did change `httpResponseTest`s generation to `call()`
// the operation handler trait implementation instead of directly calling `from_request()`.
// If no request body is defined, then no assertions are made about the body of the message.
if (testCase.body.isPresent) {
@ -470,11 +620,10 @@ class ServerProtocolTestGenerator(
checkHeaders(rustWriter, "&http_response.headers()", testCase.headers)
// We can't check that the `OperationExtension` is set in the response, because it is set in the implementation
// of the operation `Handler` trait, a code path that does not get exercised by `httpRequestTest` test cases.
// TODO(https://github.com/awslabs/smithy-rs/issues/1212): We could change test case generation so as to `call()`
// the operation handler trait implementation instead of directly calling `from_request()`, or we could run an
// actual service.
// checkHttpOperationExtension(rustWriter)
// of the operation `Handler` trait, a code path that does not get exercised when we don't have a request to
// invoke it with (like in the case of an `httpResponseTest` test case).
// In https://github.com/awslabs/smithy-rs/pull/1708: We did change `httpResponseTest`s generation to `call()`
// the operation handler trait implementation instead of directly calling `from_request()`.
// If no request body is defined, then no assertions are made about the body of the message.
if (testCase.body.isEmpty) return
@ -522,22 +671,6 @@ class ServerProtocolTestGenerator(
}
}
private fun checkHttpOperationExtension(rustWriter: RustWriter) {
rustWriter.rustTemplate(
"""
let operation_extension = http_response.extensions()
.get::<#{SmithyHttpServer}::extension::OperationExtension>()
.expect("extension `OperationExtension` not found");
""".trimIndent(),
*codegenScope,
)
rustWriter.writeWithNoFormatting(
"""
assert_eq!(operation_extension.absolute(), format!("{}.{}", "${operationShape.id.namespace}", "${operationSymbol.name}"));
""".trimIndent(),
)
}
private fun checkStatusCode(rustWriter: RustWriter, statusCode: Int) {
rustWriter.rustTemplate(
"""
@ -782,7 +915,7 @@ class ServerProtocolTestGenerator(
// or because they are flaky
private val DisableTests = setOf<String>()
private fun fixRestJsonSupportsNaNFloatQueryValues(testCase: HttpRequestTestCase): HttpRequestTestCase {
private fun fixRestJsonSupportsNaNFloatQueryValues(testCase: HttpRequestTestCase, operationShape: OperationShape): HttpRequestTestCase {
val params = Node.parse(
"""
{
@ -798,7 +931,7 @@ class ServerProtocolTestGenerator(
return testCase.toBuilder().params(params).build()
}
private fun fixRestJsonSupportsInfinityFloatQueryValues(testCase: HttpRequestTestCase): HttpRequestTestCase =
private fun fixRestJsonSupportsInfinityFloatQueryValues(testCase: HttpRequestTestCase, operationShape: OperationShape): HttpRequestTestCase =
testCase.toBuilder().params(
Node.parse(
"""
@ -813,7 +946,7 @@ class ServerProtocolTestGenerator(
""".trimMargin(),
).asObjectNode().get(),
).build()
private fun fixRestJsonSupportsNegativeInfinityFloatQueryValues(testCase: HttpRequestTestCase): HttpRequestTestCase =
private fun fixRestJsonSupportsNegativeInfinityFloatQueryValues(testCase: HttpRequestTestCase, operationShape: OperationShape): HttpRequestTestCase =
testCase.toBuilder().params(
Node.parse(
"""
@ -828,7 +961,7 @@ class ServerProtocolTestGenerator(
""".trimMargin(),
).asObjectNode().get(),
).build()
private fun fixRestJsonAllQueryStringTypes(testCase: HttpRequestTestCase): HttpRequestTestCase =
private fun fixRestJsonAllQueryStringTypes(testCase: HttpRequestTestCase, operationShape: OperationShape): HttpRequestTestCase =
testCase.toBuilder().params(
Node.parse(
"""
@ -875,7 +1008,7 @@ class ServerProtocolTestGenerator(
""".trimMargin(),
).asObjectNode().get(),
).build()
private fun fixRestJsonQueryStringEscaping(testCase: HttpRequestTestCase): HttpRequestTestCase =
private fun fixRestJsonQueryStringEscaping(testCase: HttpRequestTestCase, operationShape: OperationShape): HttpRequestTestCase =
testCase.toBuilder().params(
Node.parse(
"""
@ -889,6 +1022,9 @@ class ServerProtocolTestGenerator(
).asObjectNode().get(),
).build()
private fun fixAwsJson11MissingHeaderXAmzTarget(testCase: HttpRequestTestCase, operationShape: OperationShape): HttpRequestTestCase =
testCase.toBuilder().putHeader("x-amz-target", "JsonProtocol.${operationShape.id.name}").build()
// These are tests whose definitions in the `awslabs/smithy` repository are wrong.
// This is because they have not been written from a server perspective, and as such the expected `params` field is incomplete.
// TODO(https://github.com/awslabs/smithy-rs/issues/1288): Contribute a PR to fix them upstream.
@ -899,6 +1035,45 @@ class ServerProtocolTestGenerator(
Pair(RestJson, "RestJsonSupportsNegativeInfinityFloatQueryValues") to ::fixRestJsonSupportsNegativeInfinityFloatQueryValues,
Pair(RestJson, "RestJsonAllQueryStringTypes") to ::fixRestJsonAllQueryStringTypes,
Pair(RestJson, "RestJsonQueryStringEscaping") to ::fixRestJsonQueryStringEscaping,
// https://github.com/awslabs/smithy/pull/1392
// Missing `X-Amz-Target` in response header
Pair(AwsJson11, "AwsJson11Enums") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "AwsJson11ListsSerializeNull") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "AwsJson11MapsSerializeNullValues") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "AwsJson11ServersDontDeserializeNullStructureValues") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "PutAndGetInlineDocumentsInput") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "json_1_1_client_sends_empty_payload_for_no_input_shape") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "json_1_1_service_supports_empty_payload_for_no_input_shape") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "sends_requests_to_slash") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_blob_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_boolean_shapes_false") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_boolean_shapes_true") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_double_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_empty_list_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_empty_map_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_empty_structure_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_float_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_integer_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_list_of_map_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_list_of_recursive_structure_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_list_of_structure_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_list_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_long_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_map_of_list_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_map_of_recursive_structure_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_map_of_structure_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_map_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_recursive_structure_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_string_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_string_shapes_with_jsonvalue_trait") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_structure_members_with_locationname_traits") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_structure_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_structure_which_have_no_members") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_timestamp_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_timestamp_shapes_with_httpdate_timestampformat") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_timestamp_shapes_with_iso8601_timestampformat") to ::fixAwsJson11MissingHeaderXAmzTarget,
Pair(AwsJson11, "serializes_timestamp_shapes_with_unixtimestamp_timestampformat") to ::fixAwsJson11MissingHeaderXAmzTarget,
)
private val BrokenResponseTests: Map<Pair<String, String>, KFunction1<HttpResponseTestCase, HttpResponseTestCase>> = mapOf()