Consider `NaN` to be equal to itself in server protocol tests (#1177)

The client protocol tests use the
`aws_smithy_protocol_test::FloatEquals` for this [0].

Note we're only applying this to direct floating point shape members,
i.e. this commit _does not_ address #1147.

[0]: https://docs.rs/aws-smithy-protocol-test/latest/aws_smithy_protocol_test/trait.FloatEquals.html
This commit is contained in:
david-perez 2022-02-23 12:01:19 +01:00 committed by GitHub
parent e1099324e1
commit 97a49f3c12
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 36 additions and 10 deletions

View File

@ -8,6 +8,8 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators.protocol
import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.model.knowledge.OperationIndex
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.DoubleShape
import software.amazon.smithy.model.shapes.FloatShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.ErrorTrait
@ -357,7 +359,40 @@ class ServerProtocolTestGenerator(
}
}
} else {
rustWriter.rustTemplate("#{AssertEq}(parsed, expected);", *codegenScope)
val hasFloatingPointMembers = inputShape.members().any {
val target = model.expectShape(it.target)
(target is DoubleShape) || (target is FloatShape)
}
// TODO(https://github.com/awslabs/smithy-rs/issues/1147) Handle the case of nested floating point members.
if (hasFloatingPointMembers) {
for (member in inputShape.members()) {
val memberName = codegenContext.symbolProvider.toMemberName(member)
when (codegenContext.model.expectShape(member.target)) {
is DoubleShape, is FloatShape -> {
rustWriter.addUseImports(
RuntimeType.ProtocolTestHelper(codegenContext.runtimeConfig, "FloatEquals").toSymbol()
)
rustWriter.rust(
"""
assert!(parsed.$memberName.float_equals(&expected.$memberName),
"Unexpected value for `$memberName` {:?} vs. {:?}", expected.$memberName, parsed.$memberName);
"""
)
}
else -> {
rustWriter.rustTemplate(
"""
#{AssertEq}(parsed.$memberName, expected.$memberName, "Unexpected value for `$memberName`");
""",
*codegenScope
)
}
}
}
} else {
rustWriter.rustTemplate("#{AssertEq}(parsed, expected);", *codegenScope)
}
}
}
@ -503,11 +538,9 @@ class ServerProtocolTestGenerator(
private val ExpectFail = setOf<FailingTest>(
// Headers.
FailingTest(RestJson, "RestJsonHttpWithHeadersButNoPayload", Action.Request),
FailingTest(RestJson, "RestJsonSupportsNaNFloatHeaderInputs", Action.Request),
FailingTest(RestJson, "RestJsonInputAndOutputWithQuotedStringHeaders", Action.Response),
FailingTest(RestJson, "RestJsonUnitInputAndOutputNoOutput", Action.Response),
FailingTest(RestJson, "RestJsonSupportsNaNFloatQueryValues", Action.Request),
FailingTest(RestJson, "RestJsonEndpointTrait", Action.Request),
FailingTest(RestJson, "RestJsonEndpointTraitWithHostLabel", Action.Request),
FailingTest(RestJson, "RestJsonFooErrorUsingCode", Action.Response),
@ -516,10 +549,8 @@ class ServerProtocolTestGenerator(
FailingTest(RestJson, "RestJsonFooErrorWithDunderType", Action.Response),
FailingTest(RestJson, "RestJsonFooErrorWithDunderTypeAndNamespace", Action.Response),
FailingTest(RestJson, "RestJsonFooErrorWithDunderTypeUriAndNamespace", Action.Response),
FailingTest(RestJson, "RestJsonSupportsNaNFloatLabels", Action.Request),
FailingTest(RestJson, "RestJsonHttpResponseCode", Action.Response),
FailingTest(RestJson, "RestJsonNoInputAndNoOutput", Action.Response),
FailingTest(RestJson, "RestJsonSupportsNaNFloatInputs", Action.Request),
FailingTest(RestJson, "RestJsonStreamingTraitsRequireLengthWithBlob", Action.Response),
FailingTest(RestJson, "RestJsonHttpWithEmptyBlobPayload", Action.Request),
FailingTest(RestJson, "RestJsonHttpWithEmptyStructurePayload", Action.Request),
@ -540,11 +571,6 @@ class ServerProtocolTestGenerator(
private val DisableTests = setOf<String>()
private fun fixRestJsonSupportsNaNFloatQueryValues(testCase: HttpRequestTestCase): HttpRequestTestCase {
// TODO This test does not pass, even after fixing it with this function, because, in IEEE 754 floating
// point numbers, `NaN` is not equal to any other floating point number, even itself! So we can't compare it
// to any "expected" value.
// Reference: https://doc.rust-lang.org/std/primitive.f32.html
// Request for guidance about this test to Smithy team: https://github.com/awslabs/smithy/pull/1040#discussion_r780418707
val params = Node.parse(
"""
{