Smithy 1.9/1.10 Upgrade (#618)

* smithy 1.9.1 upgrade & primitive encode/decode

This upgrades to Smithy 1.10, but the major change is a complete overhaul of how primitives are formatted and parsed. Primitive serialization was migrated and unified into Smithy Types with the end requirement of dealing with special float serialization semantics.

* Switch to Smithy Core S3 Customization Trait

Smithy 1.9.1 brings S3UnwrappedXmlOutput as a vended trait. This commit pulls in the new model & uses that trait.

* Fix clippy warnings

* Fix doc links

* fix kotlin formatting

* Fix s3 customization to use the operation shape

* Ensure that numbers in string don't parse as numbers

* remove unused itoa

* Apply suggestions from code review

Co-authored-by: John DiSanti <jdisanti@amazon.com>

* Fix tests, CR feedback

* rename parse to parse_smithy_primitive

* Fix some more clippy errors

* Update changelog

Co-authored-by: John DiSanti <jdisanti@amazon.com>
This commit is contained in:
Russell Cohen 2021-07-30 11:25:10 -04:00 committed by GitHub
parent 9fef09af72
commit f1a726c1d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
54 changed files with 708 additions and 229 deletions

View File

@ -2,6 +2,7 @@
**New This Week** **New This Week**
- :bug: Correctly encode HTTP Checksums using base64 instead of hex. Fixes aws-sdk-rust#164. (#615) - :bug: Correctly encode HTTP Checksums using base64 instead of hex. Fixes aws-sdk-rust#164. (#615)
- (When complete) Add profile file provider for region (#594, #xyz) - (When complete) Add profile file provider for region (#594, #xyz)
- Overhaul serialization/deserialization of numeric/boolean types. This resolves issues around serialization of NaN/Infinity and should also reduce the number of allocations required during serialization. (#618)
## v0.18.1 (July 27th 2021) ## v0.18.1 (July 27th 2021)
* Remove timestreamwrite and timestreamquery from the generated services (#613) * Remove timestreamwrite and timestreamquery from the generated services (#613)

View File

@ -6,12 +6,8 @@
package software.amazon.smithy.rustsdk.customize.s3 package software.amazon.smithy.rustsdk.customize.s3
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.Writable import software.amazon.smithy.rust.codegen.rustlang.Writable
import software.amazon.smithy.rust.codegen.rustlang.asType import software.amazon.smithy.rust.codegen.rustlang.asType
@ -28,7 +24,6 @@ import software.amazon.smithy.rust.codegen.smithy.letIf
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolMap import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolMap
import software.amazon.smithy.rust.codegen.smithy.protocols.RestXml import software.amazon.smithy.rust.codegen.smithy.protocols.RestXml
import software.amazon.smithy.rust.codegen.smithy.protocols.RestXmlFactory import software.amazon.smithy.rust.codegen.smithy.protocols.RestXmlFactory
import software.amazon.smithy.rust.codegen.smithy.traits.S3UnwrappedXmlOutputTrait
import software.amazon.smithy.rustsdk.AwsRuntimeType import software.amazon.smithy.rustsdk.AwsRuntimeType
/** /**
@ -59,20 +54,6 @@ class S3Decorator : RustCodegenDecorator {
it + S3PubUse() it + S3PubUse()
} }
} }
override fun transformModel(service: ServiceShape, model: Model): Model {
return model.letIf(applies(service.id)) {
ModelTransformer.create().mapShapes(model) { shape ->
// Apply the S3UnwrappedXmlOutput customization to GetBucketLocation (more
// details on the S3UnwrappedXmlOutputTrait)
if (shape is StructureShape && shape.id == ShapeId.from("com.amazonaws.s3#GetBucketLocationOutput")) {
shape.toBuilder().addTrait(S3UnwrappedXmlOutputTrait()).build()
} else {
shape
}
}
}
}
} }
class S3(protocolConfig: ProtocolConfig) : RestXml(protocolConfig) { class S3(protocolConfig: ProtocolConfig) : RestXml(protocolConfig) {

View File

@ -4823,4 +4823,4 @@
} }
} }
} }
} }

View File

@ -935,4 +935,4 @@
} }
} }
} }
} }

View File

@ -8671,4 +8671,4 @@
} }
} }
} }
} }

View File

@ -1404,4 +1404,4 @@
} }
} }
} }
} }

View File

@ -3174,4 +3174,4 @@
} }
} }
} }
} }

View File

@ -3544,4 +3544,4 @@
} }
} }
} }
} }

View File

@ -5937,4 +5937,4 @@
} }
} }
} }
} }

View File

@ -3944,4 +3944,4 @@
} }
} }
} }
} }

View File

@ -1938,4 +1938,4 @@
} }
} }
} }
} }

View File

@ -6482,4 +6482,4 @@
} }
} }
} }
} }

View File

@ -1336,4 +1336,4 @@
} }
} }
} }
} }

View File

@ -1423,4 +1423,4 @@
"type": "boolean" "type": "boolean"
} }
} }
} }

View File

@ -4073,6 +4073,7 @@
"target": "com.amazonaws.s3#GetBucketLocationOutput" "target": "com.amazonaws.s3#GetBucketLocationOutput"
}, },
"traits": { "traits": {
"aws.customizations#s3UnwrappedXmlOutput": {},
"smithy.api#documentation": "<p>Returns the Region the bucket resides in. You set the bucket's Region using the\n <code>LocationConstraint</code> request parameter in a <code>CreateBucket</code>\n request. For more information, see <a href=\"https://docs.aws.amazon.com/AmazonS3/latest/API/API_CreateBucket.html\">CreateBucket</a>.</p>\n\n <p> To use this implementation of the operation, you must be the bucket owner.</p>\n\n <p>The following operations are related to <code>GetBucketLocation</code>:</p>\n <ul>\n <li>\n <p>\n <a href=\"https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetObject.html\">GetObject</a>\n </p>\n </li>\n <li>\n <p>\n <a href=\"https://docs.aws.amazon.com/AmazonS3/latest/API/API_CreateBucket.html\">CreateBucket</a>\n </p>\n </li>\n </ul>", "smithy.api#documentation": "<p>Returns the Region the bucket resides in. You set the bucket's Region using the\n <code>LocationConstraint</code> request parameter in a <code>CreateBucket</code>\n request. For more information, see <a href=\"https://docs.aws.amazon.com/AmazonS3/latest/API/API_CreateBucket.html\">CreateBucket</a>.</p>\n\n <p> To use this implementation of the operation, you must be the bucket owner.</p>\n\n <p>The following operations are related to <code>GetBucketLocation</code>:</p>\n <ul>\n <li>\n <p>\n <a href=\"https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetObject.html\">GetObject</a>\n </p>\n </li>\n <li>\n <p>\n <a href=\"https://docs.aws.amazon.com/AmazonS3/latest/API/API_CreateBucket.html\">CreateBucket</a>\n </p>\n </li>\n </ul>",
"smithy.api#http": { "smithy.api#http": {
"method": "GET", "method": "GET",

View File

@ -4041,4 +4041,4 @@
} }
} }
} }
} }

View File

@ -11,4 +11,4 @@ batch = { package = "aws-sdk-batch", path = "../../build/aws-sdk/batch" }
aws-types = { path = "../../build/aws-sdk/aws-types" } aws-types = { path = "../../build/aws-sdk/aws-types" }
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
structopt = { version = "0.3", default-features = false } structopt = { version = "0.3", default-features = false }
tracing-subscriber = "0.2.18" tracing-subscriber = "0.2.18"

View File

@ -14,4 +14,4 @@ tokio = { version = "1", features = ["full"]}
base64 = "0.13.0" base64 = "0.13.0"
sha2 = "0.9.5" sha2 = "0.9.5"
structopt = { version = "0.3", default-features = false } structopt = { version = "0.3", default-features = false }
tracing-subscriber = "0.2.19" tracing-subscriber = "0.2.19"

View File

@ -11,4 +11,4 @@ kinesis = { package = "aws-sdk-kinesis", path = "../../build/aws-sdk/kinesis" }
aws-types = { path = "../../build/aws-sdk/aws-types" } aws-types = { path = "../../build/aws-sdk/aws-types" }
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
structopt = { version = "0.3", default-features = false } structopt = { version = "0.3", default-features = false }
tracing-subscriber = { version = "0.2.16", features = ["fmt"] } tracing-subscriber = { version = "0.2.16", features = ["fmt"] }

View File

@ -11,4 +11,4 @@ medialive = { package = "aws-sdk-medialive", path = "../../build/aws-sdk/mediali
aws-types = { path = "../../build/aws-sdk/aws-types" } aws-types = { path = "../../build/aws-sdk/aws-types" }
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
structopt = { version = "0.3", default-features = false } structopt = { version = "0.3", default-features = false }
tracing-subscriber = { version = "0.2.16", features = ["fmt"] } tracing-subscriber = { version = "0.2.16", features = ["fmt"] }

View File

@ -11,4 +11,4 @@ mediapackage = { package = "aws-sdk-mediapackage", path = "../../build/aws-sdk/m
aws-types = { path = "../../build/aws-sdk/aws-types" } aws-types = { path = "../../build/aws-sdk/aws-types" }
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
structopt = { version = "0.3", default-features = false } structopt = { version = "0.3", default-features = false }
tracing-subscriber = { version = "0.2.16", features = ["fmt"] } tracing-subscriber = { version = "0.2.16", features = ["fmt"] }

View File

@ -12,4 +12,3 @@ aws-types = { path = "../../build/aws-sdk/aws-types" }
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
structopt = { version = "0.3", default-features = false } structopt = { version = "0.3", default-features = false }
tracing-subscriber = { version = "0.2.16", features = ["fmt"] } tracing-subscriber = { version = "0.2.16", features = ["fmt"] }

View File

@ -13,5 +13,3 @@ aws-types = { path = "../../build/aws-sdk/aws-types" }
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
structopt = { version = "0.3", default-features = false } structopt = { version = "0.3", default-features = false }
tracing-subscriber = { version = "0.2.16", features = ["fmt"] } tracing-subscriber = { version = "0.2.16", features = ["fmt"] }

View File

@ -51,4 +51,4 @@ where:
If the environment variable is not set, defaults to **us-west-2**. If the environment variable is not set, defaults to **us-west-2**.
- __-v__ enables displaying additional information. - __-v__ enables displaying additional information.
## ##

View File

@ -11,4 +11,4 @@ rds = {package = "aws-sdk-rds", path = "../../build/aws-sdk/rds"}
aws-types = { path = "../../build/aws-sdk/aws-types" } aws-types = { path = "../../build/aws-sdk/aws-types" }
tokio = {version = "1", features = ["full"]} tokio = {version = "1", features = ["full"]}
structopt = { version = "0.3", default-features = false } structopt = { version = "0.3", default-features = false }
tracing-subscriber = { version = "0.2.16", features = ["fmt"] } tracing-subscriber = { version = "0.2.16", features = ["fmt"] }

View File

@ -11,4 +11,4 @@ rdsdata = {package = "aws-sdk-rdsdata", path = "../../build/aws-sdk/rdsdata"}
aws-types = { path = "../../build/aws-sdk/aws-types" } aws-types = { path = "../../build/aws-sdk/aws-types" }
tokio = {version = "1", features = ["full"]} tokio = {version = "1", features = ["full"]}
structopt = { version = "0.3", default-features = false } structopt = { version = "0.3", default-features = false }
tracing-subscriber = { version = "0.2.16", features = ["fmt"] } tracing-subscriber = { version = "0.2.16", features = ["fmt"] }

View File

@ -15,4 +15,4 @@ tokio = { version = "1", features = ["full"] }
env_logger = "0.8.2" env_logger = "0.8.2"
chrono = "0.4.19" chrono = "0.4.19"
structopt = { version = "0.3", default-features = false } structopt = { version = "0.3", default-features = false }
tracing-subscriber = "0.2.18" tracing-subscriber = "0.2.18"

View File

@ -14,4 +14,3 @@ tokio = { version = "1", features = ["full"]}
structopt = { version = "0.3", default-features = false } structopt = { version = "0.3", default-features = false }
tracing-subscriber = { version = "0.2.16", features = ["fmt"] } tracing-subscriber = { version = "0.2.16", features = ["fmt"] }

View File

@ -11,4 +11,4 @@ aws-sdk-snowball = { path = "../../build/aws-sdk/snowball" }
aws-types = { path = "../../build/aws-sdk/aws-types" } aws-types = { path = "../../build/aws-sdk/aws-types" }
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
structopt = { version = "0.3", default-features = false } structopt = { version = "0.3", default-features = false }
tracing-subscriber = "0.2.18" tracing-subscriber = "0.2.18"

View File

@ -9,4 +9,4 @@ edition = "2018"
[dependencies] [dependencies]
sqs = { package = "aws-sdk-sqs", path = "../../build/aws-sdk/sqs" } sqs = { package = "aws-sdk-sqs", path = "../../build/aws-sdk/sqs" }
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
tracing-subscriber = "0.2.18" tracing-subscriber = "0.2.18"

View File

@ -8,6 +8,17 @@ package software.amazon.smithy.rust.codegen.rustlang
import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.dq
/**
* Dereference [input]
*
* Clippy is upset about `*&`, so if [input] is already referenced, simply strip the leading '&'
*/
fun autoDeref(input: String) = if (input.startsWith("&")) {
input.removePrefix("&")
} else {
"*$input"
}
/** /**
* A hierarchy of types handled by Smithy codegen * A hierarchy of types handled by Smithy codegen
*/ */

View File

@ -7,6 +7,8 @@ package software.amazon.smithy.rust.codegen.smithy.generators
import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.model.knowledge.OperationIndex import software.amazon.smithy.model.knowledge.OperationIndex
import software.amazon.smithy.model.shapes.DoubleShape
import software.amazon.smithy.model.shapes.FloatShape
import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.model.traits.ErrorTrait
@ -306,7 +308,21 @@ class HttpProtocolTestGenerator(
);""" );"""
) )
} else { } else {
rust("""assert_eq!(parsed.$memberName, expected_output.$memberName, "Unexpected value for `$memberName`");""") when (protocolConfig.model.expectShape(member.target)) {
is DoubleShape, is FloatShape -> {
addUseImports(
RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "FloatEquals").toSymbol()
)
rust(
"""
assert!(parsed.$memberName.float_equals(&expected_output.$memberName),
"Unexpected value for `$memberName` {:?} vs. {:?}", expected_output.$memberName, parsed.$memberName);
"""
)
}
else ->
rust("""assert_eq!(parsed.$memberName, expected_output.$memberName, "Unexpected value for `$memberName`");""")
}
} }
} }
} }
@ -428,7 +444,11 @@ class HttpProtocolTestGenerator(
private val RestXml = "aws.protocoltests.restxml#RestXml" private val RestXml = "aws.protocoltests.restxml#RestXml"
private val AwsQuery = "aws.protocoltests.query#AwsQuery" private val AwsQuery = "aws.protocoltests.query#AwsQuery"
private val Ec2Query = "aws.protocoltests.ec2#AwsEc2" private val Ec2Query = "aws.protocoltests.ec2#AwsEc2"
private val ExpectFail = setOf<FailingTest>() private val ExpectFail = setOf<FailingTest>(
FailingTest(
service = RestJson, id = "RestJsonHostWithPath", action = Action.Request
)
)
private val RunOnly: Set<String>? = null private val RunOnly: Set<String>? = null
// These tests are not even attempted to be generated, either because they will not compile // These tests are not even attempted to be generated, either because they will not compile

View File

@ -114,7 +114,18 @@ class Instantiator(
// Simple Shapes // Simple Shapes
is StringShape -> renderString(writer, shape, arg as StringNode) is StringShape -> renderString(writer, shape, arg as StringNode)
is NumberShape -> writer.write(arg.asNumberNode().get()) is NumberShape -> when (arg) {
is StringNode -> {
val numberSymbol = symbolProvider.toSymbol(shape)
// support Smithy custom values, such as Infinity
writer.rust(
"""<#T as #T>::parse_smithy_primitive(${arg.value.dq()}).expect("invalid string for number")""",
numberSymbol,
CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Parse")
)
}
is NumberNode -> writer.write(arg.value)
}
is BooleanShape -> writer.write(arg.asBooleanNode().get().toString()) is BooleanShape -> writer.write(arg.asBooleanNode().get().toString())
is DocumentShape -> writer.rustBlock("") { is DocumentShape -> writer.rustBlock("") {
val smithyJson = CargoDependency.smithyJson(runtimeConfig).asType() val smithyJson = CargoDependency.smithyJson(runtimeConfig).asType()

View File

@ -5,6 +5,7 @@
package software.amazon.smithy.rust.codegen.smithy.generators.http package software.amazon.smithy.rust.codegen.smithy.generators.http
import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.model.Model import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.HttpBinding import software.amazon.smithy.model.knowledge.HttpBinding
import software.amazon.smithy.model.knowledge.HttpBindingIndex import software.amazon.smithy.model.knowledge.HttpBindingIndex
@ -19,8 +20,10 @@ import software.amazon.smithy.model.traits.HttpTrait
import software.amazon.smithy.model.traits.MediaTypeTrait import software.amazon.smithy.model.traits.MediaTypeTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.rust.codegen.rustlang.Attribute import software.amazon.smithy.rust.codegen.rustlang.Attribute
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.assignment import software.amazon.smithy.rust.codegen.rustlang.asType
import software.amazon.smithy.rust.codegen.rustlang.autoDeref
import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
@ -36,6 +39,7 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingResolver
import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectMember import software.amazon.smithy.rust.codegen.util.expectMember
import software.amazon.smithy.rust.codegen.util.hasTrait import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.isPrimitive
fun HttpTrait.uriFormatString(): String { fun HttpTrait.uriFormatString(): String {
return uri.rustFormatString("/", "/") return uri.rustFormatString("/", "/")
@ -71,6 +75,7 @@ class RequestBindingGenerator(
) { ) {
private val index = HttpBindingIndex.of(model) private val index = HttpBindingIndex.of(model)
private val buildError = runtimeConfig.operationBuildError() private val buildError = runtimeConfig.operationBuildError()
private val Encoder = CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Encoder")
constructor( constructor(
protocolConfig: ProtocolConfig, protocolConfig: ProtocolConfig,
@ -193,6 +198,9 @@ class RequestBindingGenerator(
ifSet(memberType, memberSymbol, "&self.$memberName") { field -> ifSet(memberType, memberSymbol, "&self.$memberName") { field ->
listForEach(memberType, field) { innerField, targetId -> listForEach(memberType, field) { innerField, targetId ->
val innerMemberType = model.expectShape(targetId) val innerMemberType = model.expectShape(targetId)
if (innerMemberType.isPrimitive()) {
rust("let mut encoder = #T::from(${autoDeref(innerField)});", Encoder)
}
val formatted = headerFmtFun(this, innerMemberType, memberShape, innerField) val formatted = headerFmtFun(this, innerMemberType, memberShape, innerField)
val safeName = safeName("formatted") val safeName = safeName("formatted")
write("let $safeName = $formatted;") write("let $safeName = $formatted;")
@ -241,10 +249,10 @@ class RequestBindingGenerator(
target.isListShape || target.isMemberShape -> { target.isListShape || target.isMemberShape -> {
throw IllegalArgumentException("lists should be handled at a higher level") throw IllegalArgumentException("lists should be handled at a higher level")
} }
else -> { target.isPrimitive() -> {
val func = writer.format(RuntimeType.QueryFormat(runtimeConfig, "fmt_default")) "encoder.encode()"
"$func(&$targetName)"
} }
else -> throw CodegenException("unexpected shape: $target")
} }
} }
@ -263,12 +271,13 @@ class RequestBindingGenerator(
} }
val combinedArgs = listOf(formatString, *args.toTypedArray()) val combinedArgs = listOf(formatString, *args.toTypedArray())
writer.addImport(RuntimeType.stdfmt.member("Write").toSymbol(), null) writer.addImport(RuntimeType.stdfmt.member("Write").toSymbol(), null)
writer.rustBlock("fn uri_base(&self, output: &mut String) -> Result<(), #T>", runtimeConfig.operationBuildError()) { writer.rustBlock(
"fn uri_base(&self, output: &mut String) -> Result<(), #T>",
runtimeConfig.operationBuildError()
) {
httpTrait.uri.labels.map { label -> httpTrait.uri.labels.map { label ->
val member = inputShape.expectMember(label.content) val member = inputShape.expectMember(label.content)
assignment(local(member)) { serializeLabel(member, label, local(member))
serializeLabel(member, label)
}
} }
rust("""write!(output, ${combinedArgs.joinToString(", ")}).expect("formatting should succeed");""") rust("""write!(output, ${combinedArgs.joinToString(", ")}).expect("formatting should succeed");""")
rust("Ok(())") rust("Ok(())")
@ -374,13 +383,12 @@ class RequestBindingGenerator(
throw IllegalArgumentException("lists should be handled at a higher level") throw IllegalArgumentException("lists should be handled at a higher level")
} }
else -> { else -> {
val func = writer.format(RuntimeType.QueryFormat(runtimeConfig, "fmt_default")) "${writer.format(Encoder)}::from(${autoDeref(targetName)}).encode()"
"$func(&$targetName)"
} }
} }
} }
private fun RustWriter.serializeLabel(member: MemberShape, label: SmithyPattern.Segment) { private fun RustWriter.serializeLabel(member: MemberShape, label: SmithyPattern.Segment, outputVar: String) {
val target = model.expectShape(member.target) val target = model.expectShape(member.target)
val symbol = symbolProvider.toSymbol(member) val symbol = symbolProvider.toSymbol(member)
val buildError = { val buildError = {
@ -390,37 +398,37 @@ class RequestBindingGenerator(
"cannot be empty or unset" "cannot be empty or unset"
) )
} }
rustBlock("") { val input = safeName("input")
rust("let input = &self.${symbolProvider.toMemberName(member)};") rust("let $input = &self.${symbolProvider.toMemberName(member)};")
if (symbol.isOptional()) { if (symbol.isOptional()) {
rust("let input = input.as_ref().ok_or(${buildError()})?;") rust("let $input = $input.as_ref().ok_or(${buildError()})?;")
}
when {
target.isStringShape -> {
val func = format(RuntimeType.LabelFormat(runtimeConfig, "fmt_string"))
rust("let $outputVar = $func($input, ${label.isGreedyLabel});")
} }
when { target.isTimestampShape -> {
target.isStringShape -> { val timestampFormat =
val func = format(RuntimeType.LabelFormat(runtimeConfig, "fmt_string")) index.determineTimestampFormat(member, HttpBinding.Location.LABEL, defaultTimestampFormat)
rust("let formatted = $func(input, ${label.isGreedyLabel});") val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
} val func = format(RuntimeType.LabelFormat(runtimeConfig, "fmt_timestamp"))
target.isTimestampShape -> { rust("let $outputVar = $func(&$input, ${format(timestampFormatType)});")
val timestampFormat =
index.determineTimestampFormat(member, HttpBinding.Location.LABEL, defaultTimestampFormat)
val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
val func = format(RuntimeType.LabelFormat(runtimeConfig, "fmt_timestamp"))
rust("let formatted = $func(&input, ${format(timestampFormatType)});")
}
else -> {
val func = format(RuntimeType.LabelFormat(runtimeConfig, "fmt_default"))
rust("let formatted = $func(input);")
}
} }
rust( else -> {
""" rust(
if formatted.is_empty() { "let mut ${outputVar}_encoder = #T::from(${autoDeref(input)}); let $outputVar = ${outputVar}_encoder.encode();",
Encoder
)
}
}
rust(
"""
if $outputVar.is_empty() {
return Err(${buildError()}) return Err(${buildError()})
} }
formatted
""" """
) )
}
} }
/** End URI generation **/ /** End URI generation **/
} }

View File

@ -38,6 +38,7 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.smithy.rustType import software.amazon.smithy.rust.codegen.smithy.rustType
import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.hasTrait import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.isPrimitive
import software.amazon.smithy.rust.codegen.util.isStreaming import software.amazon.smithy.rust.codegen.util.isStreaming
import software.amazon.smithy.rust.codegen.util.toSnakeCase import software.amazon.smithy.rust.codegen.util.toSnakeCase
@ -238,17 +239,22 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera
headerUtil, headerUtil,
timestampFormatType timestampFormatType
) )
} else if (coreShape.isPrimitive()) {
rust(
"let $parsedValue = #T::read_many_primitive::<${coreType.render(fullyQualified = true)}>(headers)?;",
headerUtil
)
} else { } else {
rust( rust(
"let $parsedValue: Vec<${coreType.render(true)}> = #T::read_many(headers)?;", "let $parsedValue: Vec<${coreType.render(fullyQualified = true)}> = #T::read_many_from_str(headers)?;",
headerUtil headerUtil
) )
if (coreShape.hasTrait<MediaTypeTrait>()) { if (coreShape.hasTrait<MediaTypeTrait>()) {
rustTemplate( rustTemplate(
"""let $parsedValue: std::result::Result<Vec<_>, _> = $parsedValue """let $parsedValue: std::result::Result<Vec<_>, _> = $parsedValue
.iter().map(|s| .iter().map(|s|
#{base_64_decode}(s).map_err(|_|#{header}::ParseError) #{base_64_decode}(s).map_err(|_|#{header}::ParseError::new_with_message("failed to decode base64"))
.and_then(|bytes|String::from_utf8(bytes).map_err(|_|#{header}::ParseError)) .and_then(|bytes|String::from_utf8(bytes).map_err(|_|#{header}::ParseError::new_with_message("base64 encoded data was not valid utf-8")))
).collect();""", ).collect();""",
"base_64_decode" to RuntimeType.Base64Decode(runtimeConfig), "base_64_decode" to RuntimeType.Base64Decode(runtimeConfig),
"header" to headerUtil "header" to headerUtil
@ -281,7 +287,7 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera
else -> rustTemplate( else -> rustTemplate(
""" """
if $parsedValue.len() > 1 { if $parsedValue.len() > 1 {
Err(#{header_util}::ParseError) Err(#{header_util}::ParseError::new_with_message(format!("expected one item but found {}", $parsedValue.len())))
} else { } else {
let mut $parsedValue = $parsedValue; let mut $parsedValue = $parsedValue;
Ok($parsedValue.pop()) Ok($parsedValue.pop())

View File

@ -5,6 +5,7 @@
package software.amazon.smithy.rust.codegen.smithy.protocols.parse package software.amazon.smithy.rust.codegen.smithy.protocols.parse
import software.amazon.smithy.aws.traits.customizations.S3UnwrappedXmlOutputTrait
import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.model.knowledge.HttpBinding import software.amazon.smithy.model.knowledge.HttpBinding
import software.amazon.smithy.model.knowledge.HttpBindingIndex import software.amazon.smithy.model.knowledge.HttpBindingIndex
@ -34,6 +35,7 @@ import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.rustlang.withBlock import software.amazon.smithy.rust.codegen.rustlang.withBlock
import software.amazon.smithy.rust.codegen.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig
import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator
@ -44,7 +46,6 @@ import software.amazon.smithy.rust.codegen.smithy.isOptional
import software.amazon.smithy.rust.codegen.smithy.protocols.XmlMemberIndex import software.amazon.smithy.rust.codegen.smithy.protocols.XmlMemberIndex
import software.amazon.smithy.rust.codegen.smithy.protocols.XmlNameIndex import software.amazon.smithy.rust.codegen.smithy.protocols.XmlNameIndex
import software.amazon.smithy.rust.codegen.smithy.protocols.deserializeFunctionName import software.amazon.smithy.rust.codegen.smithy.protocols.deserializeFunctionName
import software.amazon.smithy.rust.codegen.smithy.traits.S3UnwrappedXmlOutputTrait
import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectMember import software.amazon.smithy.rust.codegen.util.expectMember
import software.amazon.smithy.rust.codegen.util.hasTrait import software.amazon.smithy.rust.codegen.util.hasTrait
@ -101,7 +102,8 @@ class XmlBindingTraitParserGenerator(
"XmlError" to xmlError, "XmlError" to xmlError,
"next_start_element" to smithyXml.member("decode::next_start_element"), "next_start_element" to smithyXml.member("decode::next_start_element"),
"try_data" to smithyXml.member("decode::try_data"), "try_data" to smithyXml.member("decode::try_data"),
"ScopedDecoder" to scopedDecoder "ScopedDecoder" to scopedDecoder,
"smithy_types" to CargoDependency.SmithyTypes(runtimeConfig).asType()
) )
private val model = protocolConfig.model private val model = protocolConfig.model
private val index = HttpBindingIndex.of(model) private val index = HttpBindingIndex.of(model)
@ -192,7 +194,7 @@ class XmlBindingTraitParserGenerator(
*codegenScope *codegenScope
) )
val context = OperationWrapperContext(operationShape, shapeName, xmlError) val context = OperationWrapperContext(operationShape, shapeName, xmlError)
if (outputShape.hasTrait<S3UnwrappedXmlOutputTrait>()) { if (operationShape.hasTrait<S3UnwrappedXmlOutputTrait>()) {
unwrappedResponseParser("builder", "decoder", "start_el", outputShape.members()) unwrappedResponseParser("builder", "decoder", "start_el", outputShape.members())
} else { } else {
writeOperationWrapper(context) { tagName -> writeOperationWrapper(context) { tagName ->
@ -561,8 +563,12 @@ class XmlBindingTraitParserGenerator(
is StringShape -> parseStringInner(shape, provider) is StringShape -> parseStringInner(shape, provider)
is NumberShape, is BooleanShape -> { is NumberShape, is BooleanShape -> {
rustBlock("") { rustBlock("") {
rust("use std::str::FromStr;") withBlockTemplate(
withBlock("#T::from_str(", ")", symbolProvider.toSymbol(shape)) { "<#{shape} as #{smithy_types}::primitive::Parse>::parse_smithy_primitive(",
")",
*codegenScope,
"shape" to symbolProvider.toSymbol(shape)
) {
provider() provider()
} }
rustTemplate( rustTemplate(

View File

@ -27,6 +27,7 @@ import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.RustType import software.amazon.smithy.rust.codegen.rustlang.RustType
import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.asType import software.amazon.smithy.rust.codegen.rustlang.asType
import software.amazon.smithy.rust.codegen.rustlang.autoDeref
import software.amazon.smithy.rust.codegen.rustlang.render import software.amazon.smithy.rust.codegen.rustlang.render
import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.rustlang.rustBlock
@ -195,17 +196,6 @@ class XmlBindingTraitSerializerGenerator(
rust("scope.finish();") rust("scope.finish();")
} }
/**
* Dereference [input]
*
* Clippy is upset about `*&`, so if [input] is already referenced, simply strip the leading '&'
*/
private fun autoDeref(input: String) = if (input.startsWith("&")) {
input.removePrefix("&")
} else {
"*$input"
}
private fun RustWriter.serializeRawMember(member: MemberShape, input: String) { private fun RustWriter.serializeRawMember(member: MemberShape, input: String) {
when (val shape = model.expectShape(member.target)) { when (val shape = model.expectShape(member.target)) {
is StringShape -> if (shape.hasTrait<EnumTrait>()) { is StringShape -> if (shape.hasTrait<EnumTrait>()) {
@ -213,8 +203,9 @@ class XmlBindingTraitSerializerGenerator(
} else { } else {
rust("$input.as_ref()") rust("$input.as_ref()")
} }
is NumberShape -> rust("$input.to_string().as_ref()") is BooleanShape, is NumberShape -> {
is BooleanShape -> rust("""if ${autoDeref(input)} { "true" } else { "false" }""") rust("#T::from(${autoDeref(input)}).encode()", CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Encoder"))
}
is BlobShape -> rust("#T($input.as_ref()).as_ref()", RuntimeType.Base64Encode(runtimeConfig)) is BlobShape -> rust("#T($input.as_ref()).as_ref()", RuntimeType.Base64Encode(runtimeConfig))
is TimestampShape -> { is TimestampShape -> {
val timestampFormat = val timestampFormat =

View File

@ -1,34 +0,0 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
package software.amazon.smithy.rust.codegen.smithy.traits
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.traits.AnnotationTrait
/**
* S3's GetBucketLocation response shape can't be represented with Smithy's restXml protocol
* without customization. We add this trait to the S3 model at codegen time so that a different
* code path is taken in the XML deserialization codegen to generate code that parses the S3
* response shape correctly.
*
* From what the S3 model states, the generated parser would expect:
* ```
* <LocationConstraint xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
* <LocationConstraint>us-west-2</LocationConstraint>
* </LocationConstraint>
* ```
*
* But S3 actually responds with:
* ```
* <LocationConstraint xmlns="http://s3.amazonaws.com/doc/2006-03-01/">us-west-2</LocationConstraint>
* ```
*/
class S3UnwrappedXmlOutputTrait : AnnotationTrait(ID, Node.objectNode()) {
companion object {
val ID = ShapeId.from("smithy.api.internal#s3UnwrappedXmlOutputTrait")
}
}

View File

@ -7,7 +7,9 @@ package software.amazon.smithy.rust.codegen.util
import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.model.Model import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.BooleanShape
import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.NumberShape
import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.ShapeId
@ -65,3 +67,10 @@ inline fun <reified T : Trait> Shape.expectTrait(): T = expectTrait(T::class.jav
/** Kotlin sugar for getTrait() check. e.g. shape.getTrait<EnumTrait>() instead of shape.getTrait(EnumTrait::class.java) */ /** Kotlin sugar for getTrait() check. e.g. shape.getTrait<EnumTrait>() instead of shape.getTrait(EnumTrait::class.java) */
inline fun <reified T : Trait> Shape.getTrait(): T? = getTrait(T::class.java).orNull() inline fun <reified T : Trait> Shape.getTrait(): T? = getTrait(T::class.java).orNull()
fun Shape.isPrimitive(): Boolean {
return when (this) {
is NumberShape, is BooleanShape -> true
else -> false
}
}

View File

@ -6,7 +6,7 @@
kotlin.code.style=official kotlin.code.style=official
# codegen # codegen
smithyVersion=1.8.0 smithyVersion=1.10.0
# kotlin # kotlin
kotlinVersion=1.4.21 kotlinVersion=1.4.21

View File

@ -20,4 +20,4 @@ are to allow this crate to be compilable and testable in isolation, no client co
[dev-dependencies] [dev-dependencies]
proptest = "1" proptest = "1"
regex = "1" regex = "1"

View File

@ -15,6 +15,39 @@ use std::fmt::{self, Debug};
use thiserror::Error; use thiserror::Error;
use urlencoded::try_url_encoded_form_equivalent; use urlencoded::try_url_encoded_form_equivalent;
/// Helper trait for tests for float comparisons
///
/// This trait differs in float's default `PartialEq` implementation by considering all `NaN` values to
/// be equal.
pub trait FloatEquals {
fn float_equals(&self, other: &Self) -> bool;
}
impl FloatEquals for f64 {
fn float_equals(&self, other: &Self) -> bool {
(self.is_nan() && other.is_nan()) || self.eq(other)
}
}
impl FloatEquals for f32 {
fn float_equals(&self, other: &Self) -> bool {
(self.is_nan() && other.is_nan()) || self.eq(other)
}
}
impl<T> FloatEquals for Option<T>
where
T: FloatEquals,
{
fn float_equals(&self, other: &Self) -> bool {
match (self, other) {
(Some(this), Some(other)) => this.float_equals(other),
(None, None) => true,
_else => false,
}
}
}
#[derive(Debug, PartialEq, Eq, Error)] #[derive(Debug, PartialEq, Eq, Error)]
pub enum ProtocolTestFailure { pub enum ProtocolTestFailure {
#[error("missing query param: expected `{expected}`, found {found:?}")] #[error("missing query param: expected `{expected}`, found {found:?}")]
@ -326,7 +359,7 @@ fn try_json_eq(actual: &str, expected: &str) -> Result<(), ProtocolTestFailure>
mod tests { mod tests {
use crate::{ use crate::{
forbid_headers, forbid_query_params, require_headers, require_query_params, validate_body, forbid_headers, forbid_query_params, require_headers, require_query_params, validate_body,
validate_headers, validate_query_string, MediaType, ProtocolTestFailure, validate_headers, validate_query_string, FloatEquals, MediaType, ProtocolTestFailure,
}; };
use http::Request; use http::Request;
@ -472,4 +505,20 @@ mod tests {
validate_body(&expected, expected, MediaType::from("something/else")) validate_body(&expected, expected, MediaType::from("something/else"))
.expect("inputs matched exactly") .expect("inputs matched exactly")
} }
#[test]
fn test_float_equals() {
let a = f64::NAN;
let b = f64::NAN;
assert_ne!(a, b);
assert!(a.float_equals(&b));
assert!(!a.float_equals(&5_f64));
assert!(5.0.float_equals(&5.0));
assert!(!5.0.float_equals(&5.1));
assert!(f64::INFINITY.float_equals(&f64::INFINITY));
assert!(!f64::INFINITY.float_equals(&f64::NEG_INFINITY));
assert!(f64::NEG_INFINITY.float_equals(&f64::NEG_INFINITY));
}
} }

View File

@ -8,18 +8,40 @@
use http::header::{HeaderName, ValueIter}; use http::header::{HeaderName, ValueIter};
use http::HeaderValue; use http::HeaderValue;
use smithy_types::instant::Format; use smithy_types::instant::Format;
use smithy_types::primitive::Parse;
use smithy_types::Instant; use smithy_types::Instant;
use std::borrow::Cow;
use std::error::Error; use std::error::Error;
use std::fmt; use std::fmt;
use std::fmt::{Display, Formatter}; use std::fmt::{Display, Formatter};
use std::str::FromStr; use std::str::FromStr;
#[derive(Debug)] #[derive(Debug, Eq, PartialEq)]
pub struct ParseError; #[non_exhaustive]
pub struct ParseError {
message: Option<Cow<'static, str>>,
}
impl ParseError {
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
Self { message: None }
}
pub fn new_with_message(message: impl Into<Cow<'static, str>>) -> Self {
Self {
message: Some(message.into()),
}
}
}
impl Display for ParseError { impl Display for ParseError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "Output failed to parse in headers") write!(f, "Output failed to parse in headers")?;
if let Some(message) = &self.message {
write!(f, ". {}", message)?;
}
Ok(())
} }
} }
@ -35,9 +57,13 @@ pub fn many_dates(
) -> Result<Vec<Instant>, ParseError> { ) -> Result<Vec<Instant>, ParseError> {
let mut out = vec![]; let mut out = vec![];
for header in values { for header in values {
let mut header = header.to_str().map_err(|_| ParseError)?; let mut header = header
.to_str()
.map_err(|_| ParseError::new_with_message("header was not valid utf-8 string"))?;
while !header.is_empty() { while !header.is_empty() {
let (v, next) = Instant::read(header, format, ',').map_err(|_| ParseError)?; let (v, next) = Instant::read(header, format, ',').map_err(|err| {
ParseError::new_with_message(format!("header could not be parsed as date: {}", err))
})?;
out.push(v); out.push(v);
header = next; header = next;
} }
@ -56,16 +82,36 @@ pub fn headers_for_prefix<'a>(
.map(move |h| (&h.as_str()[key.len()..], h)) .map(move |h| (&h.as_str()[key.len()..], h))
} }
pub fn read_many_from_str<T: FromStr>(
values: ValueIter<HeaderValue>,
) -> Result<Vec<T>, ParseError> {
read_many(values, |v: &str| {
v.parse()
.map_err(|_err| ParseError::new_with_message("failed during FromString conversion"))
})
}
pub fn read_many_primitive<T: Parse>(values: ValueIter<HeaderValue>) -> Result<Vec<T>, ParseError> {
read_many(values, |v: &str| {
T::parse_smithy_primitive(v).map_err(|primitive| {
ParseError::new_with_message(format!(
"failed reading a list of primitives: {}",
primitive
))
})
})
}
/// Read many comma / header delimited values from HTTP headers for `FromStr` types /// Read many comma / header delimited values from HTTP headers for `FromStr` types
pub fn read_many<T>(values: ValueIter<HeaderValue>) -> Result<Vec<T>, ParseError> fn read_many<T>(
where values: ValueIter<HeaderValue>,
T: FromStr, f: impl Fn(&str) -> Result<T, ParseError>,
{ ) -> Result<Vec<T>, ParseError> {
let mut out = vec![]; let mut out = vec![];
for header in values { for header in values {
let mut header = header.as_bytes(); let mut header = header.as_bytes();
while !header.is_empty() { while !header.is_empty() {
let (v, next) = read_one::<T>(&header)?; let (v, next) = read_one(&header, &f)?;
out.push(v); out.push(v);
header = next; header = next;
} }
@ -83,10 +129,15 @@ pub fn one_or_none<T: FromStr>(
Some(v) => v, Some(v) => v,
None => return Ok(None), None => return Ok(None),
}; };
let value = std::str::from_utf8(first.as_bytes()).map_err(|_| ParseError)?; let value = std::str::from_utf8(first.as_bytes())
.map_err(|_| ParseError::new_with_message("invalid utf-8"))?;
match values.next() { match values.next() {
None => T::from_str(value.trim()).map_err(|_| ParseError).map(Some), None => T::from_str(value.trim())
Some(_) => Err(ParseError), .map_err(|_| ParseError::new())
.map(Some),
Some(_) => Err(ParseError::new_with_message(
"expected a single value but found multiple",
)),
} }
} }
@ -107,13 +158,14 @@ pub fn set_header_if_absent(
} }
/// Read one comma delimited value for `FromStr` types /// Read one comma delimited value for `FromStr` types
fn read_one<T>(s: &[u8]) -> Result<(T, &[u8]), ParseError> fn read_one<'a, T>(
where s: &'a [u8],
T: FromStr, f: &impl Fn(&str) -> Result<T, ParseError>,
{ ) -> Result<(T, &'a [u8]), ParseError> {
let (head, rest) = split_at_delim(s); let (head, rest) = split_at_delim(s);
let head = std::str::from_utf8(head).map_err(|_| ParseError)?; let head = std::str::from_utf8(head)
Ok((T::from_str(head.trim()).map_err(|_| ParseError)?, rest)) .map_err(|_| ParseError::new_with_message("header was not valid utf8"))?;
Ok((f(head.trim())?, rest))
} }
fn split_at_delim(s: &[u8]) -> (&[u8], &[u8]) { fn split_at_delim(s: &[u8]) -> (&[u8], &[u8]) {
@ -128,13 +180,15 @@ fn then_delim(s: &[u8]) -> Result<&[u8], ParseError> {
} else if s.starts_with(b",") { } else if s.starts_with(b",") {
Ok(&s[1..]) Ok(&s[1..])
} else { } else {
Err(ParseError) Err(ParseError::new_with_message("expected delimiter `,`"))
} }
} }
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::header::{headers_for_prefix, read_many, set_header_if_absent, ParseError}; use crate::header::{
headers_for_prefix, read_many_primitive, set_header_if_absent, ParseError,
};
use std::collections::HashMap; use std::collections::HashMap;
#[test] #[test]
@ -153,6 +207,27 @@ mod test {
); );
} }
#[test]
fn parse_floats() {
let test_request = http::Request::builder()
.header("X-Float-Multi", "0.0,Infinity,-Infinity,5555.5")
.header("X-Float-Error", "notafloat")
.body(())
.unwrap();
assert_eq!(
read_many_primitive::<f32>(test_request.headers().get_all("X-Float-Multi").iter())
.expect("valid"),
vec![0.0, f32::INFINITY, f32::NEG_INFINITY, 5555.5]
);
assert_eq!(
read_many_primitive::<f32>(test_request.headers().get_all("X-Float-Error").iter())
.expect_err("invalid"),
ParseError::new_with_message(
"failed reading a list of primitives: failed to parse input as f32"
)
)
}
#[test] #[test]
fn read_many_bools() { fn read_many_bools() {
let test_request = http::Request::builder() let test_request = http::Request::builder()
@ -164,47 +239,50 @@ mod test {
.body(()) .body(())
.unwrap(); .unwrap();
assert_eq!( assert_eq!(
read_many::<bool>(test_request.headers().get_all("X-Bool-Multi").iter()) read_many_primitive::<bool>(test_request.headers().get_all("X-Bool-Multi").iter())
.expect("valid"), .expect("valid"),
vec![true, false, true] vec![true, false, true]
); );
assert_eq!( assert_eq!(
read_many::<bool>(test_request.headers().get_all("X-Bool").iter()).unwrap(), read_many_primitive::<bool>(test_request.headers().get_all("X-Bool").iter()).unwrap(),
vec![true] vec![true]
); );
assert_eq!( assert_eq!(
read_many::<bool>(test_request.headers().get_all("X-Bool-Single").iter()).unwrap(), read_many_primitive::<bool>(test_request.headers().get_all("X-Bool-Single").iter())
.unwrap(),
vec![true, false, true, true] vec![true, false, true, true]
); );
read_many::<bool>(test_request.headers().get_all("X-Bool-Invalid").iter()) read_many_primitive::<bool>(test_request.headers().get_all("X-Bool-Invalid").iter())
.expect_err("invalid"); .expect_err("invalid");
} }
#[test] #[test]
fn read_many_u16() { fn check_read_many_i16() {
let test_request = http::Request::builder() let test_request = http::Request::builder()
.header("X-Multi", "123,456") .header("X-Multi", "123,456")
.header("X-Multi", "789") .header("X-Multi", "789")
.header("X-Num", "777") .header("X-Num", "777")
.header("X-Num-Invalid", "12ef3") .header("X-Num-Invalid", "12ef3")
.header("X-Num-Single", "1,2,3,4,5") .header("X-Num-Single", "1,2,3,-4,5")
.body(()) .body(())
.unwrap(); .unwrap();
assert_eq!( assert_eq!(
read_many::<u16>(test_request.headers().get_all("X-Multi").iter()).expect("valid"), read_many_primitive::<i16>(test_request.headers().get_all("X-Multi").iter())
.expect("valid"),
vec![123, 456, 789] vec![123, 456, 789]
); );
assert_eq!( assert_eq!(
read_many::<u16>(test_request.headers().get_all("X-Num").iter()).unwrap(), read_many_primitive::<i16>(test_request.headers().get_all("X-Num").iter()).unwrap(),
vec![777] vec![777]
); );
assert_eq!( assert_eq!(
read_many::<u16>(test_request.headers().get_all("X-Num-Single").iter()).unwrap(), read_many_primitive::<i16>(test_request.headers().get_all("X-Num-Single").iter())
vec![1, 2, 3, 4, 5] .unwrap(),
vec![1, 2, 3, -4, 5]
); );
read_many::<u16>(test_request.headers().get_all("X-Num-Invalid").iter()) read_many_primitive::<i16>(test_request.headers().get_all("X-Num-Invalid").iter())
.expect_err("invalid"); .expect_err("invalid");
} }
@ -217,15 +295,14 @@ mod test {
.header("X-Prefix-C", "777") .header("X-Prefix-C", "777")
.body(()) .body(())
.unwrap(); .unwrap();
let resp: Result<HashMap<String, Vec<u16>>, ParseError> = let resp: Result<HashMap<String, Vec<i16>>, ParseError> =
headers_for_prefix(test_request.headers(), "X-Prefix-") headers_for_prefix(test_request.headers(), "X-Prefix-")
.map(|(key, header_name)| { .map(|(key, header_name)| {
let values = test_request.headers().get_all(header_name); let values = test_request.headers().get_all(header_name);
read_many(values.iter()).map(|v| (key.to_string(), v)) read_many_primitive(values.iter()).map(|v| (key.to_string(), v))
}) })
.collect(); .collect();
let resp = resp.expect("valid"); let resp = resp.expect("valid");
println!("{:?}", resp); assert_eq!(resp.get("a"), Some(&vec![123_i16, 456_i16]));
assert_eq!(resp.get("a"), Some(&vec![123_u16, 456_u16]));
} }
} }

View File

@ -9,14 +9,9 @@
use crate::urlencode::BASE_SET; use crate::urlencode::BASE_SET;
use percent_encoding::AsciiSet; use percent_encoding::AsciiSet;
use smithy_types::Instant; use smithy_types::Instant;
use std::fmt::Debug;
const GREEDY: &AsciiSet = &BASE_SET.remove(b'/'); const GREEDY: &AsciiSet = &BASE_SET.remove(b'/');
pub fn fmt_default<T: Debug>(t: T) -> String {
format!("{:?}", t)
}
pub fn fmt_string<T: AsRef<str>>(t: T, greedy: bool) -> String { pub fn fmt_string<T: AsRef<str>>(t: T, greedy: bool) -> String {
let uri_set = if greedy { GREEDY } else { BASE_SET }; let uri_set = if greedy { GREEDY } else { BASE_SET };
percent_encoding::utf8_percent_encode(t.as_ref(), &uri_set).to_string() percent_encoding::utf8_percent_encode(t.as_ref(), &uri_set).to_string()

View File

@ -8,11 +8,6 @@ use percent_encoding::utf8_percent_encode;
/// Formatting values into the query string as specified in /// Formatting values into the query string as specified in
/// [httpQuery](https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html#httpquery-trait) /// [httpQuery](https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html#httpquery-trait)
use smithy_types::Instant; use smithy_types::Instant;
use std::fmt::Debug;
pub fn fmt_default<T: Debug>(t: T) -> String {
format!("{:?}", t)
}
pub fn fmt_string<T: AsRef<str>>(t: T) -> String { pub fn fmt_string<T: AsRef<str>>(t: T) -> String {
utf8_percent_encode(t.as_ref(), BASE_SET).to_string() utf8_percent_encode(t.as_ref(), BASE_SET).to_string()

View File

@ -5,8 +5,6 @@ authors = ["AWS Rust SDK Team <aws-sdk-rust@amazon.com>", "John DiSanti <jdisant
edition = "2018" edition = "2018"
[dependencies] [dependencies]
itoa = "0.4"
ryu = "1.0"
smithy-types = { path = "../smithy-types" } smithy-types = { path = "../smithy-types" }
[dev-dependencies] [dev-dependencies]

View File

@ -326,7 +326,11 @@ impl<'a> JsonTokenIterator<'a> {
offset, offset,
value: if floating { value: if floating {
Number::Float( Number::Float(
f64::from_str(&number_str).map_err(|_| self.error_at(start, InvalidNumber))?, f64::from_str(&number_str)
.map_err(|_| self.error_at(start, InvalidNumber))
.and_then(|f| {
must_be_finite(f).map_err(|_| self.error_at(start, InvalidNumber))
})?,
) )
} else if negative { } else if negative {
// If the negative value overflows, then stuff it into an f64 // If the negative value overflows, then stuff it into an f64
@ -484,6 +488,22 @@ impl<'a> Iterator for JsonTokenIterator<'a> {
} }
} }
fn must_be_finite(f: f64) -> Result<f64, ()> {
if f.is_finite() {
Ok(f)
} else {
Err(())
}
}
fn must_not_be_finite(f: f64) -> Result<f64, ()> {
if !f.is_finite() {
Ok(f)
} else {
Err(())
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::deserialize::token::test::{ use crate::deserialize::token::test::{

View File

@ -9,7 +9,9 @@ use smithy_types::instant::Format;
use smithy_types::{base64, Blob, Document, Instant, Number}; use smithy_types::{base64, Blob, Document, Instant, Number};
use std::borrow::Cow; use std::borrow::Cow;
use crate::deserialize::must_not_be_finite;
pub use crate::escape::Error as EscapeError; pub use crate::escape::Error as EscapeError;
use smithy_types::primitive::Parse;
use std::collections::HashMap; use std::collections::HashMap;
use std::iter::Peekable; use std::iter::Peekable;
@ -151,9 +153,45 @@ macro_rules! expect_value_or_null_fn {
} }
expect_value_or_null_fn!(expect_bool_or_null, ValueBool, bool, "Expects a [Token::ValueBool] or [Token::ValueNull], and returns the bool value if it's not null."); expect_value_or_null_fn!(expect_bool_or_null, ValueBool, bool, "Expects a [Token::ValueBool] or [Token::ValueNull], and returns the bool value if it's not null.");
expect_value_or_null_fn!(expect_number_or_null, ValueNumber, Number, "Expects a [Token::ValueNumber] or [Token::ValueNull], and returns the [Number] value if it's not null.");
expect_value_or_null_fn!(expect_string_or_null, ValueString, EscapedStr, "Expects a [Token::ValueString] or [Token::ValueNull], and returns the [EscapedStr] value if it's not null."); expect_value_or_null_fn!(expect_string_or_null, ValueString, EscapedStr, "Expects a [Token::ValueString] or [Token::ValueNull], and returns the [EscapedStr] value if it's not null.");
/// Expects a [Token::ValueString], [Token::ValueNumber] or [Token::ValueNull].
///
/// If the value is a string, it MUST be `Infinity`, `-Infinity` or `Nan`.
/// If the value is a number, it is returned directly
pub fn expect_number_or_null(
token: Option<Result<Token<'_>, Error>>,
) -> Result<Option<Number>, Error> {
match token.transpose()? {
Some(Token::ValueNull { .. }) => Ok(None),
Some(Token::ValueNumber { value, .. }) => Ok(Some(value)),
Some(Token::ValueString { value, offset }) => match value.to_unescaped() {
Err(err) => Err(Error::new(
ErrorReason::Custom(format!("expected a valid string, escape was invalid: {}", err).into()), Some(offset.0))
),
Ok(v) => f64::parse_smithy_primitive(v.as_ref())
// disregard the exact error
.map_err(|_|())
// only infinite / NaN can be used as strings
.and_then(must_not_be_finite)
.map(|float| Some(smithy_types::Number::Float(float)))
// convert to a helpful error
.map_err(|_| {
Error::new(
ErrorReason::Custom(Cow::Owned(format!(
"only `Infinity`, `-Infinity`, `NaN` can represent a float as a string but found `{}`",
v
))),
Some(offset.0),
)
}),
},
_ => Err(Error::custom(
"expected ValueString, ValueNumber, or ValueNull",
)),
}
}
/// Expects a [Token::ValueString] or [Token::ValueNull]. If the value is a string, it interprets it as a base64 encoded [Blob] value. /// Expects a [Token::ValueString] or [Token::ValueNull]. If the value is a string, it interprets it as a base64 encoded [Blob] value.
pub fn expect_blob_or_null(token: Option<Result<Token<'_>, Error>>) -> Result<Option<Blob>, Error> { pub fn expect_blob_or_null(token: Option<Result<Token<'_>, Error>>) -> Result<Option<Blob>, Error> {
Ok(match expect_string_or_null(token)? { Ok(match expect_string_or_null(token)? {
@ -386,6 +424,15 @@ pub mod test {
)) ))
} }
#[test]
fn test_non_finite_floats() {
let mut tokens = json_token_iter(b"inf");
tokens
.next()
.expect("there is a token")
.expect_err("but it is invalid, ensure that Rust float boundary cases don't parse");
}
#[test] #[test]
fn mismatched_braces() { fn mismatched_braces() {
// The skip_value function doesn't need to explicitly handle these cases since // The skip_value function doesn't need to explicitly handle these cases since
@ -466,9 +513,27 @@ pub mod test {
expect_number_or_null(value_number(0, Number::PosInt(5))) expect_number_or_null(value_number(0, Number::PosInt(5)))
); );
assert_eq!( assert_eq!(
Err(Error::custom("expected ValueNumber or ValueNull")), Err(Error::custom(
"expected ValueString, ValueNumber, or ValueNull"
)),
expect_number_or_null(value_bool(0, true)) expect_number_or_null(value_bool(0, true))
); );
assert_eq!(
Ok(Some(Number::Float(f64::INFINITY))),
expect_number_or_null(value_string(0, "Infinity"))
);
assert_eq!(
Err(Error::new(ErrorReason::Custom("only `Infinity`, `-Infinity`, `NaN` can represent a float as a string but found `123`".into()), Some(0))),
expect_number_or_null(value_string(0, "123"))
);
match expect_number_or_null(value_string(0, "NaN")) {
Ok(Some(Number::Float(v))) if v.is_nan() => {
// ok
}
not_ok => {
panic!("expected nan, found: {:?}", not_ok)
}
}
} }
#[test] #[test]
@ -505,8 +570,14 @@ pub mod test {
Ok(Some(Instant::from_f64(1445412480.0))), Ok(Some(Instant::from_f64(1445412480.0))),
expect_timestamp_or_null(value_string(0, "2015-10-21T07:28:00Z"), Format::DateTime) expect_timestamp_or_null(value_string(0, "2015-10-21T07:28:00Z"), Format::DateTime)
); );
let err = Error::new(
ErrorReason::Custom(
"only `Infinity`, `-Infinity`, `NaN` can represent a float as a string but found `wrong`".into(),
),
Some(0),
);
assert_eq!( assert_eq!(
Err(Error::custom("expected ValueNumber or ValueNull")), Err(err),
expect_timestamp_or_null(value_string(0, "wrong"), Format::EpochSeconds) expect_timestamp_or_null(value_string(0, "wrong"), Format::EpochSeconds)
); );
assert_eq!( assert_eq!(

View File

@ -5,6 +5,7 @@
use crate::escape::escape_string; use crate::escape::escape_string;
use smithy_types::instant::Format; use smithy_types::instant::Format;
use smithy_types::primitive::Encoder;
use smithy_types::{Document, Instant, Number}; use smithy_types::{Document, Instant, Number};
use std::borrow::Cow; use std::borrow::Cow;
@ -76,19 +77,18 @@ impl<'a> JsonValueWriter<'a> {
match value { match value {
Number::PosInt(value) => { Number::PosInt(value) => {
// itoa::Buffer is a fixed-size stack allocation, so this is cheap // itoa::Buffer is a fixed-size stack allocation, so this is cheap
self.output.push_str(itoa::Buffer::new().format(value)); self.output.push_str(Encoder::from(value).encode());
} }
Number::NegInt(value) => { Number::NegInt(value) => {
self.output.push_str(itoa::Buffer::new().format(value)); self.output.push_str(Encoder::from(value).encode());
} }
Number::Float(value) => { Number::Float(value) => {
// If the value is NaN, Infinity, or -Infinity let mut encoder: Encoder = value.into();
if value.is_nan() || value.is_infinite() { // Nan / infinite values actually get written in quotes as a string value
self.output.push_str("null"); if value.is_infinite() || value.is_nan() {
self.string_unchecked(encoder.encode())
} else { } else {
// ryu::Buffer is a fixed-size stack allocation, so this is cheap self.output.push_str(encoder.encode())
self.output
.push_str(ryu::Buffer::new().format_finite(value));
} }
} }
} }
@ -394,18 +394,15 @@ mod tests {
assert_eq!("10000000000.0", format_test_number(Number::Float(1e10))); assert_eq!("10000000000.0", format_test_number(Number::Float(1e10)));
assert_eq!("-1.2", format_test_number(Number::Float(-1.2))); assert_eq!("-1.2", format_test_number(Number::Float(-1.2)));
// JSON doesn't support NaN, Infinity, or -Infinity, so we're matching // Smithy has specific behavior for infinity & NaN
// the behavior of the serde_json crate in these cases. // the behavior of the serde_json crate in these cases.
assert_eq!("\"NaN\"", format_test_number(Number::Float(f64::NAN)));
assert_eq!( assert_eq!(
serde_json::to_string(&f64::NAN).unwrap(), "\"Infinity\"",
format_test_number(Number::Float(f64::NAN))
);
assert_eq!(
serde_json::to_string(&f64::INFINITY).unwrap(),
format_test_number(Number::Float(f64::INFINITY)) format_test_number(Number::Float(f64::INFINITY))
); );
assert_eq!( assert_eq!(
serde_json::to_string(&f64::NEG_INFINITY).unwrap(), "\"-Infinity\"",
format_test_number(Number::Float(f64::NEG_INFINITY)) format_test_number(Number::Float(f64::NEG_INFINITY))
); );
} }

View File

@ -5,7 +5,5 @@ authors = ["AWS Rust SDK Team <aws-sdk-rust@amazon.com>", "John DiSanti <jdisant
edition = "2018" edition = "2018"
[dependencies] [dependencies]
itoa = "0.4"
ryu = "1.0"
urlencoding = "1.3" urlencoding = "1.3"
smithy-types = { path = "../smithy-types" } smithy-types = { path = "../smithy-types" }

View File

@ -6,6 +6,7 @@
//! Abstractions for the Smithy AWS Query protocol //! Abstractions for the Smithy AWS Query protocol
use smithy_types::instant::Format; use smithy_types::instant::Format;
use smithy_types::primitive::Encoder;
use smithy_types::{Instant, Number}; use smithy_types::{Instant, Number};
use std::borrow::Cow; use std::borrow::Cow;
use urlencoding::encode; use urlencoding::encode;
@ -168,20 +169,12 @@ impl<'a> QueryValueWriter<'a> {
match value { match value {
Number::PosInt(value) => { Number::PosInt(value) => {
// itoa::Buffer is a fixed-size stack allocation, so this is cheap // itoa::Buffer is a fixed-size stack allocation, so this is cheap
self.string(itoa::Buffer::new().format(value)); self.string(Encoder::from(value).encode());
} }
Number::NegInt(value) => { Number::NegInt(value) => {
self.string(itoa::Buffer::new().format(value)); self.string(Encoder::from(value).encode());
}
Number::Float(value) => {
// If the value is NaN, Infinity, or -Infinity
if value.is_nan() || value.is_infinite() {
self.string("");
} else {
// ryu::Buffer is a fixed-size stack allocation, so this is cheap
self.string(ryu::Buffer::new().format_finite(value));
}
} }
Number::Float(value) => self.string(Encoder::from(value).encode()),
} }
} }
@ -378,9 +371,9 @@ mod tests {
&Version=1.0\ &Version=1.0\
&PosInt=5\ &PosInt=5\
&NegInt=-5\ &NegInt=-5\
&Infinity=\ &Infinity=Infinity\
&NegInfinity=\ &NegInfinity=-Infinity\
&NaN=\ &NaN=NaN\
&Floating=5.2\ &Floating=5.2\
", ",
out out

View File

@ -11,6 +11,8 @@ default = ["chrono-conversions"]
[dependencies] [dependencies]
chrono = { version = "0.4", default-features = false, features = [] } chrono = { version = "0.4", default-features = false, features = [] }
ryu = "1.0.5"
itoa = "0.4.0"
[dev-dependencies] [dev-dependencies]
base64 = "0.13.0" base64 = "0.13.0"

View File

@ -5,6 +5,7 @@
pub mod base64; pub mod base64;
pub mod instant; pub mod instant;
pub mod primitive;
pub mod retry; pub mod retry;
use std::collections::HashMap; use std::collections::HashMap;

View File

@ -0,0 +1,276 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
//! Utilities for formatting and parsing primitives
//!
//! Smithy protocols have specific behavior for serializing
//! & deserializing floats, specifically:
//! - NaN should be serialized as `NaN`
//! - Positive infinity should be serialized as `Infinity`
//! - Negative infinity should be serialized as `-Infinity`
//!
//! This module defines the [`Parse`](Parse) trait which
//! enables parsing primitive values (numbers & booleans) that follow
//! these rules and [`Encoder`](Encoder), a struct that enables
//! allocation-free serialization.
//!
//! # Examples
//! ## Parsing
//! ```rust
//! use smithy_types::primitive::Parse;
//! let parsed = f64::parse_smithy_primitive("123.4").expect("valid float");
//! ```
//!
//! ## Encoding
//! ```
//! use smithy_types::primitive::Encoder;
//! assert_eq!("123.4", Encoder::from(123.4).encode());
//! assert_eq!("Infinity", Encoder::from(f64::INFINITY).encode());
//! assert_eq!("true", Encoder::from(true).encode());
//! ```
use crate::primitive::private::Sealed;
use std::error::Error;
use std::fmt::{Display, Formatter};
use std::str::FromStr;
/// An error during primitive parsing
#[non_exhaustive]
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct PrimitiveParseError(&'static str);
impl Display for PrimitiveParseError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "failed to parse input as {}", self.0)
}
}
impl Error for PrimitiveParseError {}
/// Sealed trait for custom parsing of primitive types
pub trait Parse: Sealed {
fn parse_smithy_primitive(input: &str) -> Result<Self, PrimitiveParseError>
where
Self: Sized;
}
mod private {
pub trait Sealed {}
impl Sealed for i8 {}
impl Sealed for i16 {}
impl Sealed for i32 {}
impl Sealed for i64 {}
impl Sealed for f32 {}
impl Sealed for f64 {}
impl Sealed for u64 {}
impl Sealed for bool {}
}
macro_rules! parse_from_str {
($t: ty) => {
impl Parse for $t {
fn parse_smithy_primitive(input: &str) -> Result<Self, PrimitiveParseError> {
FromStr::from_str(input).map_err(|_| PrimitiveParseError(stringify!($t)))
}
}
};
}
parse_from_str!(bool);
parse_from_str!(i8);
parse_from_str!(i16);
parse_from_str!(i32);
parse_from_str!(i64);
impl Parse for f32 {
fn parse_smithy_primitive(input: &str) -> Result<Self, PrimitiveParseError> {
float::parse_f32(input).map_err(|_| PrimitiveParseError("f32"))
}
}
impl Parse for f64 {
fn parse_smithy_primitive(input: &str) -> Result<Self, PrimitiveParseError> {
float::parse_f64(input).map_err(|_| PrimitiveParseError("f64"))
}
}
/// Primitive Type Encoder
///
/// This type implements `From<T>` for all Smithy primitive types.
#[non_exhaustive]
pub enum Encoder {
#[non_exhaustive]
Bool(bool),
#[non_exhaustive]
I8(i8, itoa::Buffer),
#[non_exhaustive]
I16(i16, itoa::Buffer),
#[non_exhaustive]
I32(i32, itoa::Buffer),
#[non_exhaustive]
I64(i64, itoa::Buffer),
#[non_exhaustive]
U64(u64, itoa::Buffer),
#[non_exhaustive]
F32(f32, ryu::Buffer),
#[non_exhaustive]
F64(f64, ryu::Buffer),
}
impl Encoder {
pub fn encode(&mut self) -> &str {
match self {
Encoder::Bool(true) => "true",
Encoder::Bool(false) => "false",
Encoder::I8(v, buf) => buf.format(*v),
Encoder::I16(v, buf) => buf.format(*v),
Encoder::I32(v, buf) => buf.format(*v),
Encoder::I64(v, buf) => buf.format(*v),
Encoder::U64(v, buf) => buf.format(*v),
Encoder::F32(v, buf) => {
if v.is_nan() {
float::NAN
} else if *v == f32::INFINITY {
float::INFINITY
} else if *v == f32::NEG_INFINITY {
float::NEG_INFINITY
} else {
buf.format_finite(*v)
}
}
Encoder::F64(v, buf) => {
if v.is_nan() {
float::NAN
} else if *v == f64::INFINITY {
float::INFINITY
} else if *v == f64::NEG_INFINITY {
float::NEG_INFINITY
} else {
buf.format_finite(*v)
}
}
}
}
}
impl From<bool> for Encoder {
fn from(input: bool) -> Self {
Self::Bool(input)
}
}
impl From<i8> for Encoder {
fn from(input: i8) -> Self {
Self::I8(input, itoa::Buffer::new())
}
}
impl From<i16> for Encoder {
fn from(input: i16) -> Self {
Self::I16(input, itoa::Buffer::new())
}
}
impl From<i32> for Encoder {
fn from(input: i32) -> Self {
Self::I32(input, itoa::Buffer::new())
}
}
impl From<i64> for Encoder {
fn from(input: i64) -> Self {
Self::I64(input, itoa::Buffer::new())
}
}
impl From<u64> for Encoder {
fn from(input: u64) -> Self {
Self::U64(input, itoa::Buffer::new())
}
}
impl From<f32> for Encoder {
fn from(input: f32) -> Self {
Self::F32(input, ryu::Buffer::new())
}
}
impl From<f64> for Encoder {
fn from(input: f64) -> Self {
Self::F64(input, ryu::Buffer::new())
}
}
mod float {
use std::num::ParseFloatError;
pub const INFINITY: &str = "Infinity";
pub const NEG_INFINITY: &str = "-Infinity";
pub const NAN: &str = "NaN";
pub fn parse_f32(data: &str) -> Result<f32, ParseFloatError> {
match data {
INFINITY => Ok(f32::INFINITY),
NEG_INFINITY => Ok(f32::NEG_INFINITY),
NAN => Ok(f32::NAN),
other => other.parse::<f32>(),
}
}
pub fn parse_f64(data: &str) -> Result<f64, ParseFloatError> {
match data {
INFINITY => Ok(f64::INFINITY),
NEG_INFINITY => Ok(f64::NEG_INFINITY),
NAN => Ok(f64::NAN),
other => other.parse::<f64>(),
}
}
}
#[cfg(test)]
mod test {
use crate::primitive::{Encoder, Parse};
#[test]
fn bool_format() {
assert_eq!(Encoder::from(true).encode(), "true");
assert_eq!(Encoder::from(false).encode(), "false");
let err = bool::parse_smithy_primitive("not a boolean").expect_err("should fail");
assert_eq!(err.0, "bool");
assert_eq!(bool::parse_smithy_primitive("true"), Ok(true));
assert_eq!(bool::parse_smithy_primitive("false"), Ok(false));
}
#[test]
fn float_format() {
assert_eq!(Encoder::from(55_f64).encode(), "55.0");
assert_eq!(Encoder::from(f64::INFINITY).encode(), "Infinity");
assert_eq!(Encoder::from(f32::INFINITY).encode(), "Infinity");
assert_eq!(Encoder::from(f32::NEG_INFINITY).encode(), "-Infinity");
assert_eq!(Encoder::from(f64::NEG_INFINITY).encode(), "-Infinity");
assert_eq!(Encoder::from(f32::NAN).encode(), "NaN");
assert_eq!(Encoder::from(f64::NAN).encode(), "NaN");
}
#[test]
fn float_parse() {
assert_eq!(f64::parse_smithy_primitive("1234.5"), Ok(1234.5));
assert!(f64::parse_smithy_primitive("NaN").unwrap().is_nan());
assert_eq!(
f64::parse_smithy_primitive("Infinity").unwrap(),
f64::INFINITY
);
assert_eq!(
f64::parse_smithy_primitive("-Infinity").unwrap(),
f64::NEG_INFINITY
);
assert_eq!(f32::parse_smithy_primitive("1234.5"), Ok(1234.5));
assert!(f32::parse_smithy_primitive("NaN").unwrap().is_nan());
assert_eq!(
f32::parse_smithy_primitive("Infinity").unwrap(),
f32::INFINITY
);
assert_eq!(
f32::parse_smithy_primitive("-Infinity").unwrap(),
f32::NEG_INFINITY
);
}
}