mirror of https://github.com/smithy-lang/smithy-rs
Initial awsJson 1.0 and 1.1 server implementation (#1279)
To support these new protocols, we also changed the runtime Router definition as for both awsJson 1.0 and 1.1, every request MUST be sent to the root URL (/) using the HTTP "POST" method. The runtime Router now supports explicitly all the available protocols. There are still some protocol tests failing for awsJson 1.0 and 1.1. The failure are caused by the missing implementation of @endpoint trait and date parsing. Protocol tests for these 2 protocols are heavily biased towards Responses. Before announcing support for awsJson 1.0 and 1.1, we should increase the protocol tests coverage. Signed-off-by: Bigo <1781140+crisidev@users.noreply.github.com> Co-authored-by: david-perez <d@vidp.dev>
This commit is contained in:
parent
a870d2ad05
commit
2931c9e1e6
|
@ -36,6 +36,8 @@ val allCodegenTests = listOf(
|
||||||
CodegenTest("com.amazonaws.simple#SimpleService", "simple"),
|
CodegenTest("com.amazonaws.simple#SimpleService", "simple"),
|
||||||
CodegenTest("aws.protocoltests.restjson#RestJson", "rest_json"),
|
CodegenTest("aws.protocoltests.restjson#RestJson", "rest_json"),
|
||||||
CodegenTest("aws.protocoltests.restjson.validation#RestJsonValidation", "rest_json_validation"),
|
CodegenTest("aws.protocoltests.restjson.validation#RestJsonValidation", "rest_json_validation"),
|
||||||
|
CodegenTest("aws.protocoltests.json10#JsonRpc10", "json_rpc10"),
|
||||||
|
CodegenTest("aws.protocoltests.json#JsonProtocol", "json_rpc11"),
|
||||||
CodegenTest("aws.protocoltests.misc#MiscService", "misc"),
|
CodegenTest("aws.protocoltests.misc#MiscService", "misc"),
|
||||||
CodegenTest("com.amazonaws.ebs#Ebs", "ebs"),
|
CodegenTest("com.amazonaws.ebs#Ebs", "ebs"),
|
||||||
CodegenTest("com.amazonaws.s3#AmazonS3", "s3"),
|
CodegenTest("com.amazonaws.s3#AmazonS3", "s3"),
|
||||||
|
|
|
@ -38,4 +38,7 @@ object ServerRuntimeType {
|
||||||
|
|
||||||
fun ResponseRejection(runtimeConfig: RuntimeConfig) =
|
fun ResponseRejection(runtimeConfig: RuntimeConfig) =
|
||||||
RuntimeType("ResponseRejection", ServerCargoDependency.SmithyHttpServer(runtimeConfig), "${runtimeConfig.crateSrcPrefix}_http_server::rejection")
|
RuntimeType("ResponseRejection", ServerCargoDependency.SmithyHttpServer(runtimeConfig), "${runtimeConfig.crateSrcPrefix}_http_server::rejection")
|
||||||
|
|
||||||
|
fun Protocol(runtimeConfig: RuntimeConfig) =
|
||||||
|
RuntimeType("Protocol", ServerCargoDependency.SmithyHttpServer(runtimeConfig), "${runtimeConfig.crateSrcPrefix}_http_server::protocols")
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,10 @@
|
||||||
|
|
||||||
package software.amazon.smithy.rust.codegen.server.smithy.generators
|
package software.amazon.smithy.rust.codegen.server.smithy.generators
|
||||||
|
|
||||||
|
import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
|
||||||
|
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
|
||||||
|
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
|
||||||
|
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
|
||||||
import software.amazon.smithy.model.shapes.OperationShape
|
import software.amazon.smithy.model.shapes.OperationShape
|
||||||
import software.amazon.smithy.rust.codegen.rustlang.Attribute
|
import software.amazon.smithy.rust.codegen.rustlang.Attribute
|
||||||
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
|
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
|
||||||
|
@ -28,7 +32,9 @@ class ServerOperationRegistryGenerator(
|
||||||
private val httpBindingResolver: HttpBindingResolver,
|
private val httpBindingResolver: HttpBindingResolver,
|
||||||
private val operations: List<OperationShape>,
|
private val operations: List<OperationShape>,
|
||||||
) {
|
) {
|
||||||
|
private val protocol = codegenContext.protocol
|
||||||
private val symbolProvider = codegenContext.symbolProvider
|
private val symbolProvider = codegenContext.symbolProvider
|
||||||
|
private val serviceName = codegenContext.serviceShape.toShapeId().name
|
||||||
private val operationNames = operations.map { symbolProvider.toSymbol(it).name.toSnakeCase() }
|
private val operationNames = operations.map { symbolProvider.toSymbol(it).name.toSnakeCase() }
|
||||||
private val runtimeConfig = codegenContext.runtimeConfig
|
private val runtimeConfig = codegenContext.runtimeConfig
|
||||||
private val codegenScope = arrayOf(
|
private val codegenScope = arrayOf(
|
||||||
|
@ -223,7 +229,7 @@ class ServerOperationRegistryGenerator(
|
||||||
rustTemplate(
|
rustTemplate(
|
||||||
"""
|
"""
|
||||||
$requestSpecs
|
$requestSpecs
|
||||||
#{Router}::from_box_clone_service_iter($towerServices)
|
#{Router}::${runtimeRouterConstructor()}($towerServices)
|
||||||
""".trimIndent(),
|
""".trimIndent(),
|
||||||
*codegenScope
|
*codegenScope
|
||||||
)
|
)
|
||||||
|
@ -241,12 +247,42 @@ class ServerOperationRegistryGenerator(
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Generate the `RequestSpec`s for an operation based on its HTTP-bound route.
|
* Finds the runtime function to construct a new `Router` based on the Protocol.
|
||||||
*/
|
*/
|
||||||
private fun OperationShape.requestSpec(): String {
|
private fun runtimeRouterConstructor(): String =
|
||||||
|
when (protocol) {
|
||||||
|
RestJson1Trait.ID -> "new_rest_json_router"
|
||||||
|
RestXmlTrait.ID -> "new_rest_xml_router"
|
||||||
|
AwsJson1_0Trait.ID -> "new_aws_json_10_router"
|
||||||
|
AwsJson1_1Trait.ID -> "new_aws_json_11_router"
|
||||||
|
else -> TODO("Protocol $protocol not supported yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Returns the `RequestSpec`s for an operation based on its HTTP-bound route.
|
||||||
|
*/
|
||||||
|
private fun OperationShape.requestSpec(): String =
|
||||||
|
when (protocol) {
|
||||||
|
RestJson1Trait.ID, RestXmlTrait.ID -> restRequestSpec()
|
||||||
|
AwsJson1_0Trait.ID, AwsJson1_1Trait.ID -> awsJsonOperationName()
|
||||||
|
else -> TODO("Protocol $protocol not supported yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Returns an AwsJson specific runtime `RequestSpec`.
|
||||||
|
*/
|
||||||
|
private fun OperationShape.awsJsonOperationName(): String {
|
||||||
|
val operationName = symbolProvider.toSymbol(this).name
|
||||||
|
// TODO(https://github.com/awslabs/smithy-rs/issues/950): Support the `endpoint` trait: https://awslabs.github.io/smithy/1.0/spec/core/endpoint-traits.html#endpoint-trait
|
||||||
|
return """String::from("$serviceName.$operationName")"""
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Generates a REST (RestJson1, RestXml) specific runtime `RequestSpec`.
|
||||||
|
*/
|
||||||
|
private fun OperationShape.restRequestSpec(): String {
|
||||||
val httpTrait = httpBindingResolver.httpTrait(this)
|
val httpTrait = httpBindingResolver.httpTrait(this)
|
||||||
val namespace = ServerRuntimeType.RequestSpecModule(runtimeConfig).fullyQualifiedName()
|
val namespace = ServerRuntimeType.RequestSpecModule(runtimeConfig).fullyQualifiedName()
|
||||||
|
|
||||||
// TODO(https://github.com/awslabs/smithy-rs/issues/950): Support the `endpoint` trait.
|
// TODO(https://github.com/awslabs/smithy-rs/issues/950): Support the `endpoint` trait.
|
||||||
val pathSegments = httpTrait.uri.segments.map {
|
val pathSegments = httpTrait.uri.segments.map {
|
||||||
"$namespace::PathSegment::" +
|
"$namespace::PathSegment::" +
|
||||||
|
@ -268,7 +304,7 @@ class ServerOperationRegistryGenerator(
|
||||||
$namespace::PathSpec::from_vector_unchecked(vec![${pathSegments.joinToString()}]),
|
$namespace::PathSpec::from_vector_unchecked(vec![${pathSegments.joinToString()}]),
|
||||||
$namespace::QuerySpec::from_vector_unchecked(vec![${querySegments.joinToString()}])
|
$namespace::QuerySpec::from_vector_unchecked(vec![${querySegments.joinToString()}])
|
||||||
)
|
)
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
""".trimIndent()
|
""".trimIndent()
|
||||||
}
|
}
|
||||||
|
|
|
@ -437,8 +437,8 @@ class ServerProtocolTestGenerator(
|
||||||
else -> {
|
else -> {
|
||||||
rustWriter.rustTemplate(
|
rustWriter.rustTemplate(
|
||||||
"""
|
"""
|
||||||
#{AssertEq}(parsed.$memberName, expected.$memberName, "Unexpected value for `$memberName`");
|
#{AssertEq}(parsed.$memberName, expected.$memberName, "Unexpected value for `$memberName`");
|
||||||
""",
|
""",
|
||||||
*codegenScope
|
*codegenScope
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -798,6 +798,18 @@ class ServerProtocolTestGenerator(
|
||||||
FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostAccelerateAddressing", TestType.Request),
|
FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostAccelerateAddressing", TestType.Request),
|
||||||
FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostDualstackAccelerateAddressing", TestType.Request),
|
FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostDualstackAccelerateAddressing", TestType.Request),
|
||||||
FailingTest("com.amazonaws.s3#AmazonS3", "S3OperationAddressingPreferred", TestType.Request),
|
FailingTest("com.amazonaws.s3#AmazonS3", "S3OperationAddressingPreferred", TestType.Request),
|
||||||
|
|
||||||
|
// AwsJson1.0 failing tests.
|
||||||
|
FailingTest("aws.protocoltests.json10#JsonRpc10", "AwsJson10EndpointTraitWithHostLabel", TestType.Request),
|
||||||
|
FailingTest("aws.protocoltests.json10#JsonRpc10", "AwsJson10EndpointTrait", TestType.Request),
|
||||||
|
|
||||||
|
// AwsJson1.1 failing tests.
|
||||||
|
FailingTest("aws.protocoltests.json#JsonProtocol", "AwsJson11EndpointTraitWithHostLabel", TestType.Request),
|
||||||
|
FailingTest("aws.protocoltests.json#JsonProtocol", "AwsJson11EndpointTrait", TestType.Request),
|
||||||
|
FailingTest("aws.protocoltests.json#JsonProtocol", "parses_httpdate_timestamps", TestType.Response),
|
||||||
|
FailingTest("aws.protocoltests.json#JsonProtocol", "parses_iso8601_timestamps", TestType.Response),
|
||||||
|
FailingTest("aws.protocoltests.json#JsonProtocol", "parses_the_request_id_from_the_response", TestType.Response),
|
||||||
|
|
||||||
)
|
)
|
||||||
private val RunOnly: Set<String>? = null
|
private val RunOnly: Set<String>? = null
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,97 @@
|
||||||
|
/*
|
||||||
|
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||||
|
* SPDX-License-Identifier: Apache-2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package software.amazon.smithy.rust.codegen.server.smithy.protocols
|
||||||
|
|
||||||
|
import software.amazon.smithy.model.Model
|
||||||
|
import software.amazon.smithy.model.shapes.OperationShape
|
||||||
|
import software.amazon.smithy.model.traits.ErrorTrait
|
||||||
|
import software.amazon.smithy.rust.codegen.rustlang.Writable
|
||||||
|
import software.amazon.smithy.rust.codegen.rustlang.escape
|
||||||
|
import software.amazon.smithy.rust.codegen.rustlang.rust
|
||||||
|
import software.amazon.smithy.rust.codegen.rustlang.writable
|
||||||
|
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
|
||||||
|
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport
|
||||||
|
import software.amazon.smithy.rust.codegen.smithy.protocols.AwsJson
|
||||||
|
import software.amazon.smithy.rust.codegen.smithy.protocols.AwsJsonVersion
|
||||||
|
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingResolver
|
||||||
|
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol
|
||||||
|
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolGeneratorFactory
|
||||||
|
import software.amazon.smithy.rust.codegen.smithy.protocols.awsJsonFieldName
|
||||||
|
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonCustomization
|
||||||
|
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSection
|
||||||
|
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSerializerGenerator
|
||||||
|
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator
|
||||||
|
import software.amazon.smithy.rust.codegen.util.hasTrait
|
||||||
|
|
||||||
|
/*
|
||||||
|
* AwsJson 1.0 and 1.1 server-side protocol factory. This factory creates the [ServerHttpBoundProtocolGenerator]
|
||||||
|
* with AwsJson specific configurations.
|
||||||
|
*/
|
||||||
|
class ServerAwsJsonFactory(private val version: AwsJsonVersion) : ProtocolGeneratorFactory<ServerHttpBoundProtocolGenerator> {
|
||||||
|
override fun protocol(codegenContext: CodegenContext): Protocol = ServerAwsJson(codegenContext, version)
|
||||||
|
|
||||||
|
override fun buildProtocolGenerator(codegenContext: CodegenContext): ServerHttpBoundProtocolGenerator =
|
||||||
|
ServerHttpBoundProtocolGenerator(codegenContext, protocol(codegenContext))
|
||||||
|
|
||||||
|
override fun transformModel(model: Model): Model = model
|
||||||
|
|
||||||
|
override fun support(): ProtocolSupport {
|
||||||
|
return ProtocolSupport(
|
||||||
|
/* Client support */
|
||||||
|
requestSerialization = false,
|
||||||
|
requestBodySerialization = false,
|
||||||
|
responseDeserialization = false,
|
||||||
|
errorDeserialization = false,
|
||||||
|
/* Server support */
|
||||||
|
requestDeserialization = true,
|
||||||
|
requestBodyDeserialization = true,
|
||||||
|
responseSerialization = true,
|
||||||
|
errorSerialization = true
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* AwsJson requires errors to be serialized with an additional "__type" field. This
|
||||||
|
* customization writes the right field depending on the version of the AwsJson protocol.
|
||||||
|
*/
|
||||||
|
class ServerAwsJsonError(private val awsJsonVersion: AwsJsonVersion) : JsonCustomization() {
|
||||||
|
override fun section(section: JsonSection): Writable = when (section) {
|
||||||
|
is JsonSection.ServerError -> writable {
|
||||||
|
if (section.structureShape.hasTrait<ErrorTrait>()) {
|
||||||
|
val typeId = when (awsJsonVersion) {
|
||||||
|
// AwsJson 1.0 wants the whole shape ID (namespace#Shape).
|
||||||
|
// https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#operation-error-serialization
|
||||||
|
AwsJsonVersion.Json10 -> section.structureShape.id.toString()
|
||||||
|
// AwsJson 1.1 wants only the shape name (Shape).
|
||||||
|
// https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#operation-error-serialization
|
||||||
|
AwsJsonVersion.Json11 -> section.structureShape.id.name.toString()
|
||||||
|
}
|
||||||
|
rust("""${section.jsonObject}.key("__type").string("${escape(typeId)}");""")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* AwsJson requires errors to be serialized with an additional "__type" field. This class
|
||||||
|
* customizes [JsonSerializerGenerator] to add this functionality.
|
||||||
|
*/
|
||||||
|
class ServerAwsJsonSerializerGenerator(
|
||||||
|
private val codegenContext: CodegenContext,
|
||||||
|
private val httpBindingResolver: HttpBindingResolver,
|
||||||
|
private val awsJsonVersion: AwsJsonVersion,
|
||||||
|
private val jsonSerializerGenerator: JsonSerializerGenerator =
|
||||||
|
JsonSerializerGenerator(codegenContext, httpBindingResolver, ::awsJsonFieldName, customizations = listOf(ServerAwsJsonError(awsJsonVersion)))
|
||||||
|
) : StructuredDataSerializerGenerator by jsonSerializerGenerator
|
||||||
|
|
||||||
|
class ServerAwsJson(
|
||||||
|
private val codegenContext: CodegenContext,
|
||||||
|
private val awsJsonVersion: AwsJsonVersion
|
||||||
|
) : AwsJson(codegenContext, awsJsonVersion) {
|
||||||
|
override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator =
|
||||||
|
ServerAwsJsonSerializerGenerator(codegenContext, httpBindingResolver, awsJsonVersion)
|
||||||
|
}
|
|
@ -5,6 +5,8 @@
|
||||||
|
|
||||||
package software.amazon.smithy.rust.codegen.server.smithy.protocols
|
package software.amazon.smithy.rust.codegen.server.smithy.protocols
|
||||||
|
|
||||||
|
import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
|
||||||
|
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
|
||||||
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
|
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
|
||||||
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
|
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
|
||||||
import software.amazon.smithy.codegen.core.Symbol
|
import software.amazon.smithy.codegen.core.Symbol
|
||||||
|
@ -1064,10 +1066,16 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
|
||||||
private fun getContentTypeCheck(): String {
|
private fun getContentTypeCheck(): String {
|
||||||
when (codegenContext.protocol) {
|
when (codegenContext.protocol) {
|
||||||
RestJson1Trait.ID -> {
|
RestJson1Trait.ID -> {
|
||||||
return "check_json_content_type"
|
return "check_rest_json_1_content_type"
|
||||||
}
|
}
|
||||||
RestXmlTrait.ID -> {
|
RestXmlTrait.ID -> {
|
||||||
return "check_xml_content_type"
|
return "check_rest_xml_content_type"
|
||||||
|
}
|
||||||
|
AwsJson1_0Trait.ID -> {
|
||||||
|
return "check_aws_json_10_content_type"
|
||||||
|
}
|
||||||
|
AwsJson1_1Trait.ID -> {
|
||||||
|
return "check_aws_json_11_content_type"
|
||||||
}
|
}
|
||||||
else -> {
|
else -> {
|
||||||
TODO("Protocol ${codegenContext.protocol} not supported yet")
|
TODO("Protocol ${codegenContext.protocol} not supported yet")
|
||||||
|
@ -1086,7 +1094,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
|
||||||
return ServerRuntimeType.RequestRejection(runtimeConfig)
|
return ServerRuntimeType.RequestRejection(runtimeConfig)
|
||||||
}
|
}
|
||||||
when (codegenContext.protocol) {
|
when (codegenContext.protocol) {
|
||||||
RestJson1Trait.ID -> {
|
RestJson1Trait.ID, AwsJson1_0Trait.ID, AwsJson1_1Trait.ID -> {
|
||||||
return CargoDependency.smithyJson(runtimeConfig).asType().member("deserialize").member("Error")
|
return CargoDependency.smithyJson(runtimeConfig).asType().member("deserialize").member("Error")
|
||||||
}
|
}
|
||||||
RestXmlTrait.ID -> {
|
RestXmlTrait.ID -> {
|
||||||
|
|
|
@ -5,6 +5,8 @@
|
||||||
|
|
||||||
package software.amazon.smithy.rust.codegen.server.smithy.protocols
|
package software.amazon.smithy.rust.codegen.server.smithy.protocols
|
||||||
|
|
||||||
|
import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
|
||||||
|
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
|
||||||
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
|
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
|
||||||
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
|
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
|
||||||
import software.amazon.smithy.codegen.core.CodegenException
|
import software.amazon.smithy.codegen.core.CodegenException
|
||||||
|
@ -14,6 +16,7 @@ import software.amazon.smithy.model.shapes.ServiceShape
|
||||||
import software.amazon.smithy.model.shapes.ShapeId
|
import software.amazon.smithy.model.shapes.ShapeId
|
||||||
import software.amazon.smithy.model.traits.Trait
|
import software.amazon.smithy.model.traits.Trait
|
||||||
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator
|
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator
|
||||||
|
import software.amazon.smithy.rust.codegen.smithy.protocols.AwsJsonVersion
|
||||||
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolGeneratorFactory
|
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolGeneratorFactory
|
||||||
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolMap
|
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolMap
|
||||||
|
|
||||||
|
@ -39,6 +42,8 @@ class ServerProtocolLoader(private val supportedProtocols: ProtocolMap) {
|
||||||
val DefaultProtocols = mapOf(
|
val DefaultProtocols = mapOf(
|
||||||
RestJson1Trait.ID to ServerRestJsonFactory(),
|
RestJson1Trait.ID to ServerRestJsonFactory(),
|
||||||
RestXmlTrait.ID to ServerRestXmlFactory(),
|
RestXmlTrait.ID to ServerRestXmlFactory(),
|
||||||
|
AwsJson1_0Trait.ID to ServerAwsJsonFactory(AwsJsonVersion.Json10),
|
||||||
|
AwsJson1_1Trait.ID to ServerAwsJsonFactory(AwsJsonVersion.Json11),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,9 +16,8 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.RestXml
|
||||||
* RestXml server-side protocol factory. This factory creates the [ServerHttpProtocolGenerator]
|
* RestXml server-side protocol factory. This factory creates the [ServerHttpProtocolGenerator]
|
||||||
* with RestXml specific configurations.
|
* with RestXml specific configurations.
|
||||||
*/
|
*/
|
||||||
class ServerRestXmlFactory(private val generator: (CodegenContext) -> Protocol = { RestXml(it) }) :
|
class ServerRestXmlFactory : ProtocolGeneratorFactory<ServerHttpBoundProtocolGenerator> {
|
||||||
ProtocolGeneratorFactory<ServerHttpBoundProtocolGenerator> {
|
override fun protocol(codegenContext: CodegenContext): Protocol = RestXml(codegenContext)
|
||||||
override fun protocol(codegenContext: CodegenContext): Protocol = generator(codegenContext)
|
|
||||||
|
|
||||||
override fun buildProtocolGenerator(codegenContext: CodegenContext): ServerHttpBoundProtocolGenerator =
|
override fun buildProtocolGenerator(codegenContext: CodegenContext): ServerHttpBoundProtocolGenerator =
|
||||||
ServerHttpBoundProtocolGenerator(codegenContext, RestXml(codegenContext))
|
ServerHttpBoundProtocolGenerator(codegenContext, RestXml(codegenContext))
|
||||||
|
|
|
@ -128,7 +128,7 @@ class AwsJsonSerializerGenerator(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class AwsJson(
|
open class AwsJson(
|
||||||
private val codegenContext: CodegenContext,
|
private val codegenContext: CodegenContext,
|
||||||
awsJsonVersion: AwsJsonVersion
|
awsJsonVersion: AwsJsonVersion
|
||||||
) : Protocol {
|
) : Protocol {
|
||||||
|
@ -183,6 +183,4 @@ class AwsJson(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun awsJsonFieldName(member: MemberShape): String {
|
fun awsJsonFieldName(member: MemberShape): String = member.memberName
|
||||||
return member.memberName
|
|
||||||
}
|
|
||||||
|
|
|
@ -34,6 +34,8 @@ import software.amazon.smithy.rust.codegen.rustlang.withBlock
|
||||||
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
|
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
|
||||||
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
|
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
|
||||||
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
|
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
|
||||||
|
import software.amazon.smithy.rust.codegen.smithy.customize.NamedSectionGenerator
|
||||||
|
import software.amazon.smithy.rust.codegen.smithy.customize.Section
|
||||||
import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator
|
import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator
|
||||||
import software.amazon.smithy.rust.codegen.smithy.generators.renderUnknownVariant
|
import software.amazon.smithy.rust.codegen.smithy.generators.renderUnknownVariant
|
||||||
import software.amazon.smithy.rust.codegen.smithy.generators.serializationError
|
import software.amazon.smithy.rust.codegen.smithy.generators.serializationError
|
||||||
|
@ -49,11 +51,25 @@ import software.amazon.smithy.rust.codegen.util.hasTrait
|
||||||
import software.amazon.smithy.rust.codegen.util.inputShape
|
import software.amazon.smithy.rust.codegen.util.inputShape
|
||||||
import software.amazon.smithy.rust.codegen.util.outputShape
|
import software.amazon.smithy.rust.codegen.util.outputShape
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Class describing a JSON section that can be used in a customization.
|
||||||
|
*/
|
||||||
|
sealed class JsonSection(name: String) : Section(name) {
|
||||||
|
/** Mutate the server error object prior to finalization. Eg: this can be used to inject `__type` to record the error type. */
|
||||||
|
data class ServerError(val structureShape: StructureShape, val jsonObject: String) : JsonSection("ServerError")
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* JSON customization.
|
||||||
|
*/
|
||||||
|
typealias JsonCustomization = NamedSectionGenerator<JsonSection>
|
||||||
|
|
||||||
class JsonSerializerGenerator(
|
class JsonSerializerGenerator(
|
||||||
codegenContext: CodegenContext,
|
codegenContext: CodegenContext,
|
||||||
private val httpBindingResolver: HttpBindingResolver,
|
private val httpBindingResolver: HttpBindingResolver,
|
||||||
/** Function that maps a MemberShape into a JSON field name */
|
/** Function that maps a MemberShape into a JSON field name */
|
||||||
private val jsonName: (MemberShape) -> String,
|
private val jsonName: (MemberShape) -> String,
|
||||||
|
private val customizations: List<JsonCustomization> = listOf(),
|
||||||
) : StructuredDataSerializerGenerator {
|
) : StructuredDataSerializerGenerator {
|
||||||
private data class Context<T : Shape>(
|
private data class Context<T : Shape>(
|
||||||
/** Expression that retrieves a JsonValueWriter from either a JsonObjectWriter or JsonArrayWriter */
|
/** Expression that retrieves a JsonValueWriter from either a JsonObjectWriter or JsonArrayWriter */
|
||||||
|
@ -155,7 +171,7 @@ class JsonSerializerGenerator(
|
||||||
private fun serverStructureSerializer(
|
private fun serverStructureSerializer(
|
||||||
fnName: String,
|
fnName: String,
|
||||||
structureShape: StructureShape,
|
structureShape: StructureShape,
|
||||||
includedMembers: List<MemberShape>
|
includedMembers: List<MemberShape>,
|
||||||
): RuntimeType {
|
): RuntimeType {
|
||||||
return RuntimeType.forInlineFun(fnName, operationSerModule) {
|
return RuntimeType.forInlineFun(fnName, operationSerModule) {
|
||||||
it.rustBlockTemplate(
|
it.rustBlockTemplate(
|
||||||
|
@ -166,6 +182,7 @@ class JsonSerializerGenerator(
|
||||||
rust("let mut out = String::new();")
|
rust("let mut out = String::new();")
|
||||||
rustTemplate("let mut object = #{JsonObjectWriter}::new(&mut out);", *codegenScope)
|
rustTemplate("let mut object = #{JsonObjectWriter}::new(&mut out);", *codegenScope)
|
||||||
serializeStructure(StructContext("object", "value", structureShape), includedMembers)
|
serializeStructure(StructContext("object", "value", structureShape), includedMembers)
|
||||||
|
customizations.forEach { it.section(JsonSection.ServerError(structureShape, "object"))(this) }
|
||||||
rust("object.finish();")
|
rust("object.finish();")
|
||||||
rustTemplate("Ok(out)", *codegenScope)
|
rustTemplate("Ok(out)", *codegenScope)
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,6 +29,7 @@ http-body = "0.4"
|
||||||
hyper = { version = "0.14", features = ["server", "http1", "http2", "tcp", "stream"] }
|
hyper = { version = "0.14", features = ["server", "http1", "http2", "tcp", "stream"] }
|
||||||
mime = "0.3"
|
mime = "0.3"
|
||||||
nom = "7"
|
nom = "7"
|
||||||
|
paste = "1"
|
||||||
pin-project-lite = "0.2"
|
pin-project-lite = "0.2"
|
||||||
regex = "1.0"
|
regex = "1.0"
|
||||||
serde_urlencoded = "0.7"
|
serde_urlencoded = "0.7"
|
||||||
|
|
|
@ -6,51 +6,206 @@
|
||||||
//! Protocol helpers.
|
//! Protocol helpers.
|
||||||
use crate::rejection::RequestRejection;
|
use crate::rejection::RequestRejection;
|
||||||
use axum_core::extract::RequestParts;
|
use axum_core::extract::RequestParts;
|
||||||
|
use paste::paste;
|
||||||
|
|
||||||
#[derive(Debug)]
|
/// Supported protocols.
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub enum Protocol {
|
pub enum Protocol {
|
||||||
RestJson1,
|
RestJson1,
|
||||||
RestXml,
|
RestXml,
|
||||||
|
AwsJson10,
|
||||||
|
AwsJson11,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Validate that the request had the standard JSON content-type header.
|
/// Implement the content-type header validation for a request.
|
||||||
pub fn check_json_content_type<B>(req: &RequestParts<B>) -> Result<(), RequestRejection> {
|
macro_rules! impl_content_type_validation {
|
||||||
let mime = req
|
($name:literal, $type: literal, $subtype:literal, $rejection:path) => {
|
||||||
.headers()
|
paste! {
|
||||||
.ok_or(RequestRejection::MissingJsonContentType)?
|
#[doc = concat!("Validates that the request has the standard `", $type, "/", $subtype, "` content-type header.")]
|
||||||
.get(http::header::CONTENT_TYPE)
|
pub fn [<check_ $name _content_type>]<B>(req: &RequestParts<B>) -> Result<(), RequestRejection> {
|
||||||
.ok_or(RequestRejection::MissingJsonContentType)?
|
let mime = req
|
||||||
.to_str()
|
.headers()
|
||||||
.map_err(|_| RequestRejection::MissingJsonContentType)?
|
.ok_or($rejection)?
|
||||||
.parse::<mime::Mime>()
|
.get(http::header::CONTENT_TYPE)
|
||||||
.map_err(|_| RequestRejection::MimeParse)?;
|
.ok_or($rejection)?
|
||||||
|
.to_str()
|
||||||
|
.map_err(|_| $rejection)?
|
||||||
|
.parse::<mime::Mime>()
|
||||||
|
.map_err(|_| RequestRejection::MimeParse)?;
|
||||||
|
if mime.type_() == $type && mime.subtype() == $subtype {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err($rejection)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
if mime.type_() == "application"
|
impl_content_type_validation!(
|
||||||
&& (mime.subtype() == "json" || mime.suffix().filter(|name| *name == "json").is_some())
|
"rest_json_1",
|
||||||
{
|
"application",
|
||||||
Ok(())
|
"json",
|
||||||
} else {
|
RequestRejection::MissingRestJson1ContentType
|
||||||
Err(RequestRejection::MissingJsonContentType)
|
);
|
||||||
}
|
|
||||||
}
|
impl_content_type_validation!(
|
||||||
|
"rest_xml",
|
||||||
/// Validate that the request had the standard XML content-type header.
|
"application",
|
||||||
pub fn check_xml_content_type<B>(req: &RequestParts<B>) -> Result<(), RequestRejection> {
|
"xml",
|
||||||
let mime = req
|
RequestRejection::MissingRestXmlContentType
|
||||||
.headers()
|
);
|
||||||
.ok_or(RequestRejection::MissingXmlContentType)?
|
|
||||||
.get(http::header::CONTENT_TYPE)
|
impl_content_type_validation!(
|
||||||
.ok_or(RequestRejection::MissingXmlContentType)?
|
"aws_json_10",
|
||||||
.to_str()
|
"application",
|
||||||
.map_err(|_| RequestRejection::MissingXmlContentType)?
|
"x-amz-json-1.0",
|
||||||
.parse::<mime::Mime>()
|
RequestRejection::MissingAwsJson10ContentType
|
||||||
.map_err(|_| RequestRejection::MimeParse)?;
|
);
|
||||||
|
|
||||||
if mime.type_() == "application"
|
impl_content_type_validation!(
|
||||||
&& (mime.subtype() == "xml" || mime.suffix().filter(|name| *name == "xml").is_some())
|
"aws_json_11",
|
||||||
{
|
"application",
|
||||||
Ok(())
|
"x-amz-json-1.1",
|
||||||
} else {
|
RequestRejection::MissingAwsJson11ContentType
|
||||||
Err(RequestRejection::MissingXmlContentType)
|
);
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use http::Request;
|
||||||
|
|
||||||
|
fn req(content_type: &str) -> RequestParts<&str> {
|
||||||
|
let request = Request::builder()
|
||||||
|
.header("content-type", content_type)
|
||||||
|
.body("")
|
||||||
|
.unwrap();
|
||||||
|
RequestParts::new(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This macro validates the rejection type since we cannot implement `PartialEq`
|
||||||
|
/// for `RequestRejection` as it is based on the crate error type, which uses
|
||||||
|
/// `axum_core::BoxError`.
|
||||||
|
macro_rules! validate_rejection_type {
|
||||||
|
($result:expr, $rejection:path) => {
|
||||||
|
match $result {
|
||||||
|
Ok(()) => panic!("Content-type validation is expected to fail"),
|
||||||
|
Err(e) => match e {
|
||||||
|
$rejection => {}
|
||||||
|
_ => panic!("Error {} should be {}", e.to_string(), stringify!($rejection)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn validate_rest_json_1_content_type() {
|
||||||
|
// Check valid content-type header.
|
||||||
|
let request = req("application/json");
|
||||||
|
assert!(check_rest_json_1_content_type(&request).is_ok());
|
||||||
|
|
||||||
|
// Check invalid content-type header.
|
||||||
|
let invalid = vec![
|
||||||
|
req("application/ajson"),
|
||||||
|
req("application/json1"),
|
||||||
|
req("applicatio/json"),
|
||||||
|
req("application/xml"),
|
||||||
|
req("text/xml"),
|
||||||
|
req("application/x-amz-json-1.0"),
|
||||||
|
req("application/x-amz-json-1.1"),
|
||||||
|
RequestParts::new(Request::builder().body("").unwrap()),
|
||||||
|
];
|
||||||
|
for request in &invalid {
|
||||||
|
validate_rejection_type!(
|
||||||
|
check_rest_json_1_content_type(request),
|
||||||
|
RequestRejection::MissingRestJson1ContentType
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check request with not parsable content-type header.
|
||||||
|
validate_rejection_type!(check_rest_json_1_content_type(&req("123")), RequestRejection::MimeParse);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn validate_rest_xml_content_type() {
|
||||||
|
// Check valid content-type header.
|
||||||
|
let request = req("application/xml");
|
||||||
|
assert!(check_rest_xml_content_type(&request).is_ok());
|
||||||
|
|
||||||
|
// Check invalid content-type header.
|
||||||
|
let invalid = vec![
|
||||||
|
req("application/axml"),
|
||||||
|
req("application/xml1"),
|
||||||
|
req("applicatio/xml"),
|
||||||
|
req("text/xml"),
|
||||||
|
req("application/x-amz-json-1.0"),
|
||||||
|
req("application/x-amz-json-1.1"),
|
||||||
|
RequestParts::new(Request::builder().body("").unwrap()),
|
||||||
|
];
|
||||||
|
for request in &invalid {
|
||||||
|
validate_rejection_type!(
|
||||||
|
check_rest_xml_content_type(request),
|
||||||
|
RequestRejection::MissingRestXmlContentType
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check request with not parsable content-type header.
|
||||||
|
validate_rejection_type!(check_rest_xml_content_type(&req("123")), RequestRejection::MimeParse);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn validate_aws_json_10_content_type() {
|
||||||
|
// Check valid content-type header.
|
||||||
|
let request = req("application/x-amz-json-1.0");
|
||||||
|
assert!(check_aws_json_10_content_type(&request).is_ok());
|
||||||
|
|
||||||
|
// Check invalid content-type header.
|
||||||
|
let invalid = vec![
|
||||||
|
req("application/x-amz-json-1."),
|
||||||
|
req("application/-amz-json-1.0"),
|
||||||
|
req("application/xml"),
|
||||||
|
req("application/json"),
|
||||||
|
req("applicatio/x-amz-json-1.0"),
|
||||||
|
req("text/xml"),
|
||||||
|
req("application/x-amz-json-1.1"),
|
||||||
|
RequestParts::new(Request::builder().body("").unwrap()),
|
||||||
|
];
|
||||||
|
for request in &invalid {
|
||||||
|
validate_rejection_type!(
|
||||||
|
check_aws_json_10_content_type(request),
|
||||||
|
RequestRejection::MissingAwsJson10ContentType
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check request with not parsable content-type header.
|
||||||
|
validate_rejection_type!(check_aws_json_10_content_type(&req("123")), RequestRejection::MimeParse);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn validate_aws_json_11_content_type() {
|
||||||
|
// Check valid content-type header.
|
||||||
|
let request = req("application/x-amz-json-1.1");
|
||||||
|
assert!(check_aws_json_11_content_type(&request).is_ok());
|
||||||
|
|
||||||
|
// Check invalid content-type header.
|
||||||
|
let invalid = vec![
|
||||||
|
req("application/x-amz-json-1."),
|
||||||
|
req("application/-amz-json-1.1"),
|
||||||
|
req("application/xml"),
|
||||||
|
req("application/json"),
|
||||||
|
req("applicatio/x-amz-json-1.1"),
|
||||||
|
req("text/xml"),
|
||||||
|
req("application/x-amz-json-1.0"),
|
||||||
|
RequestParts::new(Request::builder().body("").unwrap()),
|
||||||
|
];
|
||||||
|
for request in &invalid {
|
||||||
|
validate_rejection_type!(
|
||||||
|
check_aws_json_11_content_type(request),
|
||||||
|
RequestRejection::MissingAwsJson11ContentType
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check request with not parsable content-type header.
|
||||||
|
validate_rejection_type!(check_aws_json_11_content_type(&req("123")), RequestRejection::MimeParse);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -134,8 +134,10 @@ pub enum RequestRejection {
|
||||||
HttpBody(crate::Error),
|
HttpBody(crate::Error),
|
||||||
|
|
||||||
// These are used when checking the `Content-Type` header.
|
// These are used when checking the `Content-Type` header.
|
||||||
MissingJsonContentType,
|
MissingRestJson1ContentType,
|
||||||
MissingXmlContentType,
|
MissingAwsJson10ContentType,
|
||||||
|
MissingAwsJson11ContentType,
|
||||||
|
MissingRestXmlContentType,
|
||||||
MimeParse,
|
MimeParse,
|
||||||
|
|
||||||
/// Used when failing to deserialize the HTTP body's bytes into a JSON document conforming to
|
/// Used when failing to deserialize the HTTP body's bytes into a JSON document conforming to
|
||||||
|
|
|
@ -7,10 +7,15 @@
|
||||||
//!
|
//!
|
||||||
//! [Smithy specification]: https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html
|
//! [Smithy specification]: https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html
|
||||||
|
|
||||||
use self::{future::RouterFuture, request_spec::RequestSpec};
|
use self::future::RouterFuture;
|
||||||
|
use self::request_spec::RequestSpec;
|
||||||
use crate::body::{boxed, Body, BoxBody, HttpBody};
|
use crate::body::{boxed, Body, BoxBody, HttpBody};
|
||||||
|
use crate::protocols::Protocol;
|
||||||
|
use crate::runtime_error::{RuntimeError, RuntimeErrorKind};
|
||||||
use crate::BoxError;
|
use crate::BoxError;
|
||||||
|
use axum_core::response::IntoResponse;
|
||||||
use http::{Request, Response, StatusCode};
|
use http::{Request, Response, StatusCode};
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::{
|
use std::{
|
||||||
convert::Infallible,
|
convert::Infallible,
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
|
@ -31,34 +36,60 @@ mod route;
|
||||||
pub use self::{into_make_service::IntoMakeService, route::Route};
|
pub use self::{into_make_service::IntoMakeService, route::Route};
|
||||||
|
|
||||||
/// The router is a [`tower::Service`] that routes incoming requests to other `Service`s
|
/// The router is a [`tower::Service`] that routes incoming requests to other `Service`s
|
||||||
/// based on the request's URI and HTTP method, adhering to the [Smithy specification].
|
/// based on the request's URI and HTTP method or on some specific header setting the target operation.
|
||||||
|
/// The former is adhering to the [Smithy specification], while the latter is adhering to
|
||||||
|
/// the [AwsJson specification].
|
||||||
|
///
|
||||||
|
/// The router is also [Protocol] aware and currently supports REST based protocols like [restJson1] or [restXml]
|
||||||
|
/// and RPC based protocols like [awsJson1.0] or [awsJson1.1].
|
||||||
/// It currently does not support Smithy's [endpoint trait].
|
/// It currently does not support Smithy's [endpoint trait].
|
||||||
///
|
///
|
||||||
/// You should not **instantiate** this router directly; it will be created for you from the
|
/// You should not **instantiate** this router directly; it will be created for you from the
|
||||||
/// code generated from your Smithy model by `smithy-rs`.
|
/// code generated from your Smithy model by `smithy-rs`.
|
||||||
///
|
///
|
||||||
/// [Smithy specification]: https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html
|
/// [Smithy specification]: https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html
|
||||||
|
/// [AwsJson specification]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#protocol-behaviors
|
||||||
|
/// [Protocol]: https://awslabs.github.io/smithy/1.0/spec/aws/index.html#aws-protocols
|
||||||
|
/// [restJson1]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-restjson1-protocol.html
|
||||||
|
/// [restXml]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-restxml-protocol.html
|
||||||
|
/// [awsJson1.0]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html
|
||||||
|
/// [awsJson1.1]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html
|
||||||
/// [endpoint trait]: https://awslabs.github.io/smithy/1.0/spec/core/endpoint-traits.html#endpoint-trait
|
/// [endpoint trait]: https://awslabs.github.io/smithy/1.0/spec/core/endpoint-traits.html#endpoint-trait
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Router<B = Body> {
|
pub struct Router<B = Body> {
|
||||||
routes: Vec<(Route<B>, RequestSpec)>,
|
routes: Routes<B>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Protocol-aware routes types.
|
||||||
|
///
|
||||||
|
/// RestJson1 and RestXml routes are stored in a `Vec` because there can be multiple matches on the
|
||||||
|
/// request URI and we thus need to iterate the whole list and use a ranking mechanism to choose.
|
||||||
|
///
|
||||||
|
/// AwsJson 1.0 and 1.1 routes can be stored in a `HashMap` since the requested operation can be
|
||||||
|
/// directly found in the `X-Amz-Target` HTTP header.
|
||||||
|
#[derive(Debug)]
|
||||||
|
enum Routes<B = Body> {
|
||||||
|
RestXml(Vec<(Route<B>, RequestSpec)>),
|
||||||
|
RestJson1(Vec<(Route<B>, RequestSpec)>),
|
||||||
|
AwsJson10(HashMap<String, Route<B>>),
|
||||||
|
AwsJson11(HashMap<String, Route<B>>),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B> Clone for Router<B> {
|
impl<B> Clone for Router<B> {
|
||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
Self {
|
match &self.routes {
|
||||||
routes: self.routes.clone(),
|
Routes::RestJson1(routes) => Router {
|
||||||
}
|
routes: Routes::RestJson1(routes.clone()),
|
||||||
}
|
},
|
||||||
}
|
Routes::RestXml(routes) => Router {
|
||||||
|
routes: Routes::RestXml(routes.clone()),
|
||||||
impl<B> Default for Router<B>
|
},
|
||||||
where
|
Routes::AwsJson10(routes) => Router {
|
||||||
B: Send + 'static,
|
routes: Routes::AwsJson10(routes.clone()),
|
||||||
{
|
},
|
||||||
fn default() -> Self {
|
Routes::AwsJson11(routes) => Router {
|
||||||
Self {
|
routes: Routes::AwsJson11(routes.clone()),
|
||||||
routes: Default::default(),
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -67,32 +98,29 @@ impl<B> Router<B>
|
||||||
where
|
where
|
||||||
B: Send + 'static,
|
B: Send + 'static,
|
||||||
{
|
{
|
||||||
/// Create a new `Router` from a vector of pairs of request specs and services.
|
/// Return the correct, protocol-specific "Not Found" response for an unknown operation.
|
||||||
///
|
fn unknown_operation(&self) -> RouterFuture<B> {
|
||||||
/// If the vector is empty the router will respond `404 Not Found` to all requests.
|
let protocol = match &self.routes {
|
||||||
#[doc(hidden)]
|
Routes::RestJson1(_) => Protocol::RestJson1,
|
||||||
pub fn from_box_clone_service_iter<T>(routes: T) -> Self
|
Routes::RestXml(_) => Protocol::RestXml,
|
||||||
where
|
Routes::AwsJson10(_) => Protocol::AwsJson10,
|
||||||
T: IntoIterator<
|
Routes::AwsJson11(_) => Protocol::AwsJson11,
|
||||||
Item = (
|
};
|
||||||
tower::util::BoxCloneService<Request<B>, Response<BoxBody>, Infallible>,
|
let error = RuntimeError {
|
||||||
RequestSpec,
|
protocol,
|
||||||
),
|
kind: RuntimeErrorKind::UnknownOperation,
|
||||||
>,
|
};
|
||||||
{
|
RouterFuture::from_response(error.into_response())
|
||||||
let mut routes: Vec<(Route<B>, RequestSpec)> = routes
|
|
||||||
.into_iter()
|
|
||||||
.map(|(svc, request_spec)| (Route::from_box_clone_service(svc), request_spec))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
// Sort them once by specifity, with the more specific routes sorted before the less
|
|
||||||
// specific ones, so that when routing a request we can simply iterate through the routes
|
|
||||||
// and pick the first one that matches.
|
|
||||||
routes.sort_by_key(|(_route, request_spec)| std::cmp::Reverse(request_spec.rank()));
|
|
||||||
|
|
||||||
Self { routes }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return the HTTP error response for non allowed method.
|
||||||
|
fn method_not_allowed(&self) -> RouterFuture<B> {
|
||||||
|
RouterFuture::from_response({
|
||||||
|
let mut res = Response::new(crate::body::empty());
|
||||||
|
*res.status_mut() = StatusCode::METHOD_NOT_ALLOWED;
|
||||||
|
res
|
||||||
|
})
|
||||||
|
}
|
||||||
/// Convert this router into a [`MakeService`], that is a [`Service`] whose
|
/// Convert this router into a [`MakeService`], that is a [`Service`] whose
|
||||||
/// response is another service.
|
/// response is another service.
|
||||||
///
|
///
|
||||||
|
@ -124,12 +152,146 @@ where
|
||||||
.layer_fn(Route::new)
|
.layer_fn(Route::new)
|
||||||
.layer(MapResponseBodyLayer::new(boxed))
|
.layer(MapResponseBodyLayer::new(boxed))
|
||||||
.layer(layer);
|
.layer(layer);
|
||||||
let routes = self
|
match self.routes {
|
||||||
.routes
|
Routes::RestJson1(routes) => {
|
||||||
|
let routes = routes
|
||||||
|
.into_iter()
|
||||||
|
.map(|(route, request_spec)| (Layer::layer(&layer, route), request_spec))
|
||||||
|
.collect();
|
||||||
|
Router {
|
||||||
|
routes: Routes::RestJson1(routes),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Routes::RestXml(routes) => {
|
||||||
|
let routes = routes
|
||||||
|
.into_iter()
|
||||||
|
.map(|(route, request_spec)| (Layer::layer(&layer, route), request_spec))
|
||||||
|
.collect();
|
||||||
|
Router {
|
||||||
|
routes: Routes::RestXml(routes),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Routes::AwsJson10(routes) => {
|
||||||
|
let routes = routes
|
||||||
|
.into_iter()
|
||||||
|
.map(|(operation, route)| (operation, Layer::layer(&layer, route)))
|
||||||
|
.collect();
|
||||||
|
Router {
|
||||||
|
routes: Routes::AwsJson10(routes),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Routes::AwsJson11(routes) => {
|
||||||
|
let routes = routes
|
||||||
|
.into_iter()
|
||||||
|
.map(|(operation, route)| (operation, Layer::layer(&layer, route)))
|
||||||
|
.collect();
|
||||||
|
Router {
|
||||||
|
routes: Routes::AwsJson11(routes),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new RestJson1 `Router` from an iterator over pairs of [`RequestSpec`]s and services.
|
||||||
|
///
|
||||||
|
/// If the iterator is empty the router will respond `404 Not Found` to all requests.
|
||||||
|
#[doc(hidden)]
|
||||||
|
pub fn new_rest_json_router<T>(routes: T) -> Self
|
||||||
|
where
|
||||||
|
T: IntoIterator<
|
||||||
|
Item = (
|
||||||
|
tower::util::BoxCloneService<Request<B>, Response<BoxBody>, Infallible>,
|
||||||
|
RequestSpec,
|
||||||
|
),
|
||||||
|
>,
|
||||||
|
{
|
||||||
|
let mut routes: Vec<(Route<B>, RequestSpec)> = routes
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(route, request_spec)| (Layer::layer(&layer, route), request_spec))
|
.map(|(svc, request_spec)| (Route::from_box_clone_service(svc), request_spec))
|
||||||
.collect();
|
.collect();
|
||||||
Router { routes }
|
|
||||||
|
// Sort them once by specifity, with the more specific routes sorted before the less
|
||||||
|
// specific ones, so that when routing a request we can simply iterate through the routes
|
||||||
|
// and pick the first one that matches.
|
||||||
|
routes.sort_by_key(|(_route, request_spec)| std::cmp::Reverse(request_spec.rank()));
|
||||||
|
|
||||||
|
Self {
|
||||||
|
routes: Routes::RestJson1(routes),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new RestXml `Router` from an iterator over pairs of [`RequestSpec`]s and services.
|
||||||
|
///
|
||||||
|
/// If the iterator is empty the router will respond `404 Not Found` to all requests.
|
||||||
|
#[doc(hidden)]
|
||||||
|
pub fn new_rest_xml_router<T>(routes: T) -> Self
|
||||||
|
where
|
||||||
|
T: IntoIterator<
|
||||||
|
Item = (
|
||||||
|
tower::util::BoxCloneService<Request<B>, Response<BoxBody>, Infallible>,
|
||||||
|
RequestSpec,
|
||||||
|
),
|
||||||
|
>,
|
||||||
|
{
|
||||||
|
let mut routes: Vec<(Route<B>, RequestSpec)> = routes
|
||||||
|
.into_iter()
|
||||||
|
.map(|(svc, request_spec)| (Route::from_box_clone_service(svc), request_spec))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Sort them once by specifity, with the more specific routes sorted before the less
|
||||||
|
// specific ones, so that when routing a request we can simply iterate through the routes
|
||||||
|
// and pick the first one that matches.
|
||||||
|
routes.sort_by_key(|(_route, request_spec)| std::cmp::Reverse(request_spec.rank()));
|
||||||
|
|
||||||
|
Self {
|
||||||
|
routes: Routes::RestXml(routes),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new AwsJson 1.0 `Router` from an iterator over pairs of operation names and services.
|
||||||
|
///
|
||||||
|
/// If the iterator is empty the router will respond `404 Not Found` to all requests.
|
||||||
|
#[doc(hidden)]
|
||||||
|
pub fn new_aws_json_10_router<T>(routes: T) -> Self
|
||||||
|
where
|
||||||
|
T: IntoIterator<
|
||||||
|
Item = (
|
||||||
|
tower::util::BoxCloneService<Request<B>, Response<BoxBody>, Infallible>,
|
||||||
|
String,
|
||||||
|
),
|
||||||
|
>,
|
||||||
|
{
|
||||||
|
let routes = routes
|
||||||
|
.into_iter()
|
||||||
|
.map(|(svc, operation)| (operation, Route::from_box_clone_service(svc)))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
routes: Routes::AwsJson10(routes),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new AwsJson 1.1 `Router` from a vector of pairs of operations and services.
|
||||||
|
///
|
||||||
|
/// If the vector is empty the router will respond `404 Not Found` to all requests.
|
||||||
|
#[doc(hidden)]
|
||||||
|
pub fn new_aws_json_11_router<T>(routes: T) -> Self
|
||||||
|
where
|
||||||
|
T: IntoIterator<
|
||||||
|
Item = (
|
||||||
|
tower::util::BoxCloneService<Request<B>, Response<BoxBody>, Infallible>,
|
||||||
|
String,
|
||||||
|
),
|
||||||
|
>,
|
||||||
|
{
|
||||||
|
let routes = routes
|
||||||
|
.into_iter()
|
||||||
|
.map(|(svc, operation)| (operation, Route::from_box_clone_service(svc)))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
routes: Routes::AwsJson11(routes),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -148,44 +310,84 @@ where
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn call(&mut self, req: Request<B>) -> Self::Future {
|
fn call(&mut self, req: Request<B>) -> Self::Future {
|
||||||
let mut method_not_allowed = false;
|
match &self.routes {
|
||||||
|
// REST routes.
|
||||||
|
Routes::RestJson1(routes) | Routes::RestXml(routes) => {
|
||||||
|
let mut method_not_allowed = false;
|
||||||
|
|
||||||
for (route, request_spec) in &self.routes {
|
// Loop through all the routes and validate if any of them matches. Routes are already ranked.
|
||||||
match request_spec.matches(&req) {
|
for (route, request_spec) in routes {
|
||||||
request_spec::Match::Yes => {
|
match request_spec.matches(&req) {
|
||||||
return RouterFuture::from_oneshot(route.clone().oneshot(req));
|
request_spec::Match::Yes => {
|
||||||
|
return RouterFuture::from_oneshot(route.clone().oneshot(req));
|
||||||
|
}
|
||||||
|
request_spec::Match::MethodNotAllowed => method_not_allowed = true,
|
||||||
|
// Continue looping to see if another route matches.
|
||||||
|
request_spec::Match::No => continue,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
request_spec::Match::MethodNotAllowed => method_not_allowed = true,
|
|
||||||
// Continue looping to see if another route matches.
|
if method_not_allowed {
|
||||||
request_spec::Match::No => continue,
|
// The HTTP method is not correct.
|
||||||
|
self.method_not_allowed()
|
||||||
|
} else {
|
||||||
|
// In any other case return the `RuntimeError::UnknownOperation`.
|
||||||
|
self.unknown_operation()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// AwsJson routes.
|
||||||
|
Routes::AwsJson10(routes) | Routes::AwsJson11(routes) => {
|
||||||
|
if req.uri() == "/" {
|
||||||
|
// Check the request method for POST.
|
||||||
|
if req.method() == http::Method::POST {
|
||||||
|
// Find the `x-amz-target` header.
|
||||||
|
if let Some(target) = req.headers().get("x-amz-target") {
|
||||||
|
if let Ok(target) = target.to_str() {
|
||||||
|
// Lookup in the `HashMap` for a route for the target.
|
||||||
|
let route = routes.get(target);
|
||||||
|
if let Some(route) = route {
|
||||||
|
return RouterFuture::from_oneshot(route.clone().oneshot(req));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// The HTTP method is not POST.
|
||||||
|
return self.method_not_allowed();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// In any other case return the `RuntimeError::UnknownOperation`.
|
||||||
|
self.unknown_operation()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let status_code = if method_not_allowed {
|
|
||||||
StatusCode::METHOD_NOT_ALLOWED
|
|
||||||
} else {
|
|
||||||
StatusCode::NOT_FOUND
|
|
||||||
};
|
|
||||||
RouterFuture::from_response(
|
|
||||||
Response::builder()
|
|
||||||
.status(status_code)
|
|
||||||
.body(crate::body::empty())
|
|
||||||
.unwrap(),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod rest_tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::{body::boxed, routing::request_spec::*};
|
use crate::{body::boxed, routing::request_spec::*};
|
||||||
use futures_util::Future;
|
use futures_util::Future;
|
||||||
use http::Method;
|
use http::{HeaderMap, Method};
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
|
|
||||||
/// Helper function to build a `Request`. Used in other test modules.
|
/// Helper function to build a `Request`. Used in other test modules.
|
||||||
pub fn req(method: &Method, uri: &str) -> Request<()> {
|
pub fn req(method: &Method, uri: &str, headers: Option<HeaderMap>) -> Request<()> {
|
||||||
Request::builder().method(method).uri(uri).body(()).unwrap()
|
let mut r = Request::builder().method(method).uri(uri).body(()).unwrap();
|
||||||
|
if let Some(headers) = headers {
|
||||||
|
*r.headers_mut() = headers
|
||||||
|
}
|
||||||
|
r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns a `Response`'s body as a `String`, without consuming the response.
|
||||||
|
pub async fn get_body_as_string<B>(res: &mut Response<B>) -> String
|
||||||
|
where
|
||||||
|
B: http_body::Body + std::marker::Unpin,
|
||||||
|
B::Error: std::fmt::Debug,
|
||||||
|
{
|
||||||
|
let body_mut = res.body_mut();
|
||||||
|
let body_bytes = hyper::body::to_bytes(body_mut).await.unwrap();
|
||||||
|
String::from(std::str::from_utf8(&body_bytes).unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A service that returns its name and the request's URI in the response body.
|
/// A service that returns its name and the request's URI in the response body.
|
||||||
|
@ -210,17 +412,6 @@ mod tests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns a `Response`'s body as a `String`, without consuming the response.
|
|
||||||
async fn get_body_as_str<B>(res: &mut Response<B>) -> String
|
|
||||||
where
|
|
||||||
B: http_body::Body + std::marker::Unpin,
|
|
||||||
B::Error: std::fmt::Debug,
|
|
||||||
{
|
|
||||||
let body_mut = res.body_mut();
|
|
||||||
let body_bytes = hyper::body::to_bytes(body_mut).await.unwrap();
|
|
||||||
String::from(std::str::from_utf8(&body_bytes).unwrap())
|
|
||||||
}
|
|
||||||
|
|
||||||
// This test is a rewrite of `mux.spec.ts`.
|
// This test is a rewrite of `mux.spec.ts`.
|
||||||
// https://github.com/awslabs/smithy-typescript/blob/fbf97a9bf4c1d8cf7f285ea7c24e1f0ef280142a/smithy-typescript-ssdk-libs/server-common/src/httpbinding/mux.spec.ts
|
// https://github.com/awslabs/smithy-typescript/blob/fbf97a9bf4c1d8cf7f285ea7c24e1f0ef280142a/smithy-typescript-ssdk-libs/server-common/src/httpbinding/mux.spec.ts
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
@ -271,55 +462,64 @@ mod tests {
|
||||||
),
|
),
|
||||||
];
|
];
|
||||||
|
|
||||||
let mut router = Router::from_box_clone_service_iter(request_specs.into_iter().map(|(spec, svc_name)| {
|
// Test both RestJson1 and RestXml routers.
|
||||||
|
let router_json = Router::new_rest_json_router(request_specs.clone().into_iter().map(|(spec, svc_name)| {
|
||||||
|
(
|
||||||
|
tower::util::BoxCloneService::new(NamedEchoUriService(String::from(svc_name))),
|
||||||
|
spec,
|
||||||
|
)
|
||||||
|
}));
|
||||||
|
let router_xml = Router::new_rest_xml_router(request_specs.into_iter().map(|(spec, svc_name)| {
|
||||||
(
|
(
|
||||||
tower::util::BoxCloneService::new(NamedEchoUriService(String::from(svc_name))),
|
tower::util::BoxCloneService::new(NamedEchoUriService(String::from(svc_name))),
|
||||||
spec,
|
spec,
|
||||||
)
|
)
|
||||||
}));
|
}));
|
||||||
|
|
||||||
let hits = vec![
|
for mut router in [router_json, router_xml] {
|
||||||
("A", Method::GET, "/a/b/c"),
|
let hits = vec![
|
||||||
("MiddleGreedy", Method::GET, "/mg/a/z"),
|
("A", Method::GET, "/a/b/c"),
|
||||||
("MiddleGreedy", Method::GET, "/mg/a/b/c/d/z?abc=def"),
|
("MiddleGreedy", Method::GET, "/mg/a/z"),
|
||||||
("Delete", Method::DELETE, "/?foo=bar&baz=quux"),
|
("MiddleGreedy", Method::GET, "/mg/a/b/c/d/z?abc=def"),
|
||||||
("Delete", Method::DELETE, "/?foo=bar&baz"),
|
("Delete", Method::DELETE, "/?foo=bar&baz=quux"),
|
||||||
("Delete", Method::DELETE, "/?foo=bar&baz=&"),
|
("Delete", Method::DELETE, "/?foo=bar&baz"),
|
||||||
("Delete", Method::DELETE, "/?foo=bar&baz=quux&baz=grault"),
|
("Delete", Method::DELETE, "/?foo=bar&baz=&"),
|
||||||
("QueryKeyOnly", Method::POST, "/query_key_only?foo=bar"),
|
("Delete", Method::DELETE, "/?foo=bar&baz=quux&baz=grault"),
|
||||||
("QueryKeyOnly", Method::POST, "/query_key_only?foo"),
|
("QueryKeyOnly", Method::POST, "/query_key_only?foo=bar"),
|
||||||
("QueryKeyOnly", Method::POST, "/query_key_only?foo="),
|
("QueryKeyOnly", Method::POST, "/query_key_only?foo"),
|
||||||
("QueryKeyOnly", Method::POST, "/query_key_only?foo=&"),
|
("QueryKeyOnly", Method::POST, "/query_key_only?foo="),
|
||||||
];
|
("QueryKeyOnly", Method::POST, "/query_key_only?foo=&"),
|
||||||
for (svc_name, method, uri) in &hits {
|
];
|
||||||
let mut res = router.call(req(method, uri)).await.unwrap();
|
for (svc_name, method, uri) in &hits {
|
||||||
let actual_body = get_body_as_str(&mut res).await;
|
let mut res = router.call(req(method, uri, None)).await.unwrap();
|
||||||
|
let actual_body = get_body_as_string(&mut res).await;
|
||||||
|
|
||||||
assert_eq!(format!("{} :: {}", svc_name, uri), actual_body);
|
assert_eq!(format!("{} :: {}", svc_name, uri), actual_body);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (_, _, uri) in hits {
|
for (_, _, uri) in hits {
|
||||||
let res = router.call(req(&Method::PATCH, uri)).await.unwrap();
|
let res = router.call(req(&Method::PATCH, uri, None)).await.unwrap();
|
||||||
assert_eq!(StatusCode::METHOD_NOT_ALLOWED, res.status());
|
assert_eq!(StatusCode::METHOD_NOT_ALLOWED, res.status());
|
||||||
}
|
}
|
||||||
|
|
||||||
let misses = vec![
|
let misses = vec![
|
||||||
(Method::GET, "/a"),
|
(Method::GET, "/a"),
|
||||||
(Method::GET, "/a/b"),
|
(Method::GET, "/a/b"),
|
||||||
(Method::GET, "/mg"),
|
(Method::GET, "/mg"),
|
||||||
(Method::GET, "/mg/q"),
|
(Method::GET, "/mg/q"),
|
||||||
(Method::GET, "/mg/z"),
|
(Method::GET, "/mg/z"),
|
||||||
(Method::GET, "/mg/a/b/z/c"),
|
(Method::GET, "/mg/a/b/z/c"),
|
||||||
(Method::DELETE, "/?foo=bar"),
|
(Method::DELETE, "/?foo=bar"),
|
||||||
(Method::DELETE, "/?foo=bar"),
|
(Method::DELETE, "/?foo=bar"),
|
||||||
(Method::DELETE, "/?baz=quux"),
|
(Method::DELETE, "/?baz=quux"),
|
||||||
(Method::POST, "/query_key_only?baz=quux"),
|
(Method::POST, "/query_key_only?baz=quux"),
|
||||||
(Method::GET, "/"),
|
(Method::GET, "/"),
|
||||||
(Method::POST, "/"),
|
(Method::POST, "/"),
|
||||||
];
|
];
|
||||||
for (method, miss) in misses {
|
for (method, miss) in misses {
|
||||||
let res = router.call(req(&method, miss)).await.unwrap();
|
let res = router.call(req(&method, miss, None)).await.unwrap();
|
||||||
assert_eq!(StatusCode::NOT_FOUND, res.status());
|
assert_eq!(StatusCode::NOT_FOUND, res.status());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -364,7 +564,7 @@ mod tests {
|
||||||
),
|
),
|
||||||
];
|
];
|
||||||
|
|
||||||
let mut router = Router::from_box_clone_service_iter(request_specs.into_iter().map(|(spec, svc_name)| {
|
let mut router = Router::new_rest_json_router(request_specs.into_iter().map(|(spec, svc_name)| {
|
||||||
(
|
(
|
||||||
tower::util::BoxCloneService::new(NamedEchoUriService(String::from(svc_name))),
|
tower::util::BoxCloneService::new(NamedEchoUriService(String::from(svc_name))),
|
||||||
spec,
|
spec,
|
||||||
|
@ -378,10 +578,96 @@ mod tests {
|
||||||
("B2", Method::GET, "/b/foo?q=baz"),
|
("B2", Method::GET, "/b/foo?q=baz"),
|
||||||
];
|
];
|
||||||
for (svc_name, method, uri) in &hits {
|
for (svc_name, method, uri) in &hits {
|
||||||
let mut res = router.call(req(method, uri)).await.unwrap();
|
let mut res = router.call(req(method, uri, None)).await.unwrap();
|
||||||
let actual_body = get_body_as_str(&mut res).await;
|
let actual_body = get_body_as_string(&mut res).await;
|
||||||
|
|
||||||
assert_eq!(format!("{} :: {}", svc_name, uri), actual_body);
|
assert_eq!(format!("{} :: {}", svc_name, uri), actual_body);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod awsjson_tests {
|
||||||
|
use super::rest_tests::{get_body_as_string, req};
|
||||||
|
use super::*;
|
||||||
|
use crate::body::boxed;
|
||||||
|
use futures_util::Future;
|
||||||
|
use http::{HeaderMap, HeaderValue, Method};
|
||||||
|
use pretty_assertions::assert_eq;
|
||||||
|
use std::pin::Pin;
|
||||||
|
|
||||||
|
/// A service that returns its name and the request's URI in the response body.
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct NamedEchoOperationService(String);
|
||||||
|
|
||||||
|
impl<B> Service<Request<B>> for NamedEchoOperationService {
|
||||||
|
type Response = Response<BoxBody>;
|
||||||
|
type Error = Infallible;
|
||||||
|
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn call(&mut self, req: Request<B>) -> Self::Future {
|
||||||
|
let target = req
|
||||||
|
.headers()
|
||||||
|
.get("x-amz-target")
|
||||||
|
.map(|x| x.to_str().unwrap())
|
||||||
|
.unwrap_or("unknown");
|
||||||
|
let body = boxed(Body::from(format!("{} :: {}", self.0, target)));
|
||||||
|
let fut = async { Ok(Response::builder().status(&http::StatusCode::OK).body(body).unwrap()) };
|
||||||
|
Box::pin(fut)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn simple_routing() {
|
||||||
|
let routes = vec![("Service.Operation", "A")];
|
||||||
|
let router_json10 = Router::new_aws_json_10_router(routes.clone().into_iter().map(|(operation, svc_name)| {
|
||||||
|
(
|
||||||
|
tower::util::BoxCloneService::new(NamedEchoOperationService(String::from(svc_name))),
|
||||||
|
operation.to_string(),
|
||||||
|
)
|
||||||
|
}));
|
||||||
|
let router_json11 = Router::new_aws_json_11_router(routes.into_iter().map(|(operation, svc_name)| {
|
||||||
|
(
|
||||||
|
tower::util::BoxCloneService::new(NamedEchoOperationService(String::from(svc_name))),
|
||||||
|
operation.to_string(),
|
||||||
|
)
|
||||||
|
}));
|
||||||
|
|
||||||
|
for mut router in [router_json10, router_json11] {
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert("x-amz-target", HeaderValue::from_static("Service.Operation"));
|
||||||
|
|
||||||
|
// Valid request, should return a valid body.
|
||||||
|
let mut res = router
|
||||||
|
.call(req(&Method::POST, "/", Some(headers.clone())))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let actual_body = get_body_as_string(&mut res).await;
|
||||||
|
assert_eq!(format!("{} :: {}", "A", "Service.Operation"), actual_body);
|
||||||
|
|
||||||
|
// No headers, should return NOT_FOUND.
|
||||||
|
let res = router.call(req(&Method::POST, "/", None)).await.unwrap();
|
||||||
|
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||||
|
|
||||||
|
// Wrong HTTP method, should return METHOD_NOT_ALLOWED.
|
||||||
|
let res = router
|
||||||
|
.call(req(&Method::GET, "/", Some(headers.clone())))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED);
|
||||||
|
|
||||||
|
// Wrong URI, should return NOT_FOUND.
|
||||||
|
let res = router
|
||||||
|
.call(req(&Method::POST, "/something", Some(headers)))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -243,7 +243,7 @@ impl RequestSpec {
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::super::tests::req;
|
use super::super::rest_tests::req;
|
||||||
use super::*;
|
use super::*;
|
||||||
use http::Method;
|
use http::Method;
|
||||||
|
|
||||||
|
@ -295,7 +295,7 @@ mod tests {
|
||||||
(Method::GET, "/mg/a/z/z/z"),
|
(Method::GET, "/mg/a/z/z/z"),
|
||||||
];
|
];
|
||||||
for (method, uri) in &hits {
|
for (method, uri) in &hits {
|
||||||
assert_eq!(Match::Yes, spec.matches(&req(method, uri)));
|
assert_eq!(Match::Yes, spec.matches(&req(method, uri, None)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -309,7 +309,7 @@ mod tests {
|
||||||
(Method::DELETE, "/?foo&foo"),
|
(Method::DELETE, "/?foo&foo"),
|
||||||
];
|
];
|
||||||
for (method, uri) in &hits {
|
for (method, uri) in &hits {
|
||||||
assert_eq!(Match::Yes, spec.matches(&req(method, uri)));
|
assert_eq!(Match::Yes, spec.matches(&req(method, uri, None)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -325,7 +325,7 @@ mod tests {
|
||||||
fn repeated_query_keys_same_values_match() {
|
fn repeated_query_keys_same_values_match() {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
Match::Yes,
|
Match::Yes,
|
||||||
key_value_spec().matches(&req(&Method::DELETE, "/?foo=bar&foo=bar"))
|
key_value_spec().matches(&req(&Method::DELETE, "/?foo=bar&foo=bar", None))
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -333,7 +333,7 @@ mod tests {
|
||||||
fn repeated_query_keys_distinct_values_does_not_match() {
|
fn repeated_query_keys_distinct_values_does_not_match() {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
Match::No,
|
Match::No,
|
||||||
key_value_spec().matches(&req(&Method::DELETE, "/?foo=bar&foo=baz"))
|
key_value_spec().matches(&req(&Method::DELETE, "/?foo=bar&foo=baz", None))
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -354,11 +354,11 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn empty_segments_in_the_middle_do_matter() {
|
fn empty_segments_in_the_middle_do_matter() {
|
||||||
assert_eq!(Match::Yes, ab_spec().matches(&req(&Method::GET, "/a/b")));
|
assert_eq!(Match::Yes, ab_spec().matches(&req(&Method::GET, "/a/b", None)));
|
||||||
|
|
||||||
let misses = vec![(Method::GET, "/a//b"), (Method::GET, "//////a//b")];
|
let misses = vec![(Method::GET, "/a//b"), (Method::GET, "//////a//b")];
|
||||||
for (method, uri) in &misses {
|
for (method, uri) in &misses {
|
||||||
assert_eq!(Match::No, ab_spec().matches(&req(method, uri)));
|
assert_eq!(Match::No, ab_spec().matches(&req(method, uri, None)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -379,10 +379,10 @@ mod tests {
|
||||||
(Method::GET, "/a//b"), // Label is bound to `""`.
|
(Method::GET, "/a//b"), // Label is bound to `""`.
|
||||||
];
|
];
|
||||||
for (method, uri) in &hits {
|
for (method, uri) in &hits {
|
||||||
assert_eq!(Match::Yes, label_spec.matches(&req(method, uri)));
|
assert_eq!(Match::Yes, label_spec.matches(&req(method, uri, None)));
|
||||||
}
|
}
|
||||||
|
|
||||||
assert_eq!(Match::No, label_spec.matches(&req(&Method::GET, "/a///b")));
|
assert_eq!(Match::No, label_spec.matches(&req(&Method::GET, "/a///b", None)));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -403,7 +403,7 @@ mod tests {
|
||||||
(Method::GET, "/a///a//b///suffix"),
|
(Method::GET, "/a///a//b///suffix"),
|
||||||
];
|
];
|
||||||
for (method, uri) in &hits {
|
for (method, uri) in &hits {
|
||||||
assert_eq!(Match::Yes, greedy_label_spec.matches(&req(method, uri)));
|
assert_eq!(Match::Yes, greedy_label_spec.matches(&req(method, uri, None)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -418,7 +418,7 @@ mod tests {
|
||||||
(Method::GET, "//a//b////"),
|
(Method::GET, "//a//b////"),
|
||||||
];
|
];
|
||||||
for (method, uri) in &misses {
|
for (method, uri) in &misses {
|
||||||
assert_eq!(Match::No, ab_spec().matches(&req(method, uri)));
|
assert_eq!(Match::No, ab_spec().matches(&req(method, uri, None)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -432,13 +432,13 @@ mod tests {
|
||||||
|
|
||||||
let misses = vec![(Method::GET, "/a"), (Method::GET, "/a//"), (Method::GET, "/a///")];
|
let misses = vec![(Method::GET, "/a"), (Method::GET, "/a//"), (Method::GET, "/a///")];
|
||||||
for (method, uri) in &misses {
|
for (method, uri) in &misses {
|
||||||
assert_eq!(Match::No, label_spec.matches(&req(method, uri)));
|
assert_eq!(Match::No, label_spec.matches(&req(method, uri, None)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// In the second example, the label is bound to `""`.
|
// In the second example, the label is bound to `""`.
|
||||||
let hits = vec![(Method::GET, "/a/label"), (Method::GET, "/a/")];
|
let hits = vec![(Method::GET, "/a/label"), (Method::GET, "/a/")];
|
||||||
for (method, uri) in &hits {
|
for (method, uri) in &hits {
|
||||||
assert_eq!(Match::Yes, label_spec.matches(&req(method, uri)));
|
assert_eq!(Match::Yes, label_spec.matches(&req(method, uri, None)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,7 +25,8 @@ use crate::protocols::Protocol;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum RuntimeErrorKind {
|
pub enum RuntimeErrorKind {
|
||||||
// UnknownOperation,
|
/// The requested operation does not exist.
|
||||||
|
UnknownOperation,
|
||||||
/// Request failed to deserialize or response failed to serialize.
|
/// Request failed to deserialize or response failed to serialize.
|
||||||
Serialization(crate::Error),
|
Serialization(crate::Error),
|
||||||
/// As of writing, this variant can only occur upon failure to extract an
|
/// As of writing, this variant can only occur upon failure to extract an
|
||||||
|
@ -43,6 +44,7 @@ impl RuntimeErrorKind {
|
||||||
match self {
|
match self {
|
||||||
RuntimeErrorKind::Serialization(_) => "SerializationException",
|
RuntimeErrorKind::Serialization(_) => "SerializationException",
|
||||||
RuntimeErrorKind::InternalFailure(_) => "InternalFailureException",
|
RuntimeErrorKind::InternalFailure(_) => "InternalFailureException",
|
||||||
|
RuntimeErrorKind::UnknownOperation => "UnknownOperation",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -58,11 +60,16 @@ impl axum_core::response::IntoResponse for RuntimeError {
|
||||||
let status_code = match self.kind {
|
let status_code = match self.kind {
|
||||||
RuntimeErrorKind::Serialization(_) => http::StatusCode::BAD_REQUEST,
|
RuntimeErrorKind::Serialization(_) => http::StatusCode::BAD_REQUEST,
|
||||||
RuntimeErrorKind::InternalFailure(_) => http::StatusCode::INTERNAL_SERVER_ERROR,
|
RuntimeErrorKind::InternalFailure(_) => http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
RuntimeErrorKind::UnknownOperation => http::StatusCode::NOT_FOUND,
|
||||||
};
|
};
|
||||||
|
|
||||||
let body = crate::body::to_boxed(match self.protocol {
|
let body = crate::body::to_boxed(match self.protocol {
|
||||||
Protocol::RestJson1 => "{}",
|
Protocol::RestJson1 => "{}",
|
||||||
Protocol::RestXml => "",
|
Protocol::RestXml => "",
|
||||||
|
// See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#empty-body-serialization
|
||||||
|
Protocol::AwsJson10 => "",
|
||||||
|
// See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#empty-body-serialization
|
||||||
|
Protocol::AwsJson11 => "",
|
||||||
});
|
});
|
||||||
|
|
||||||
let mut builder = http::Response::builder();
|
let mut builder = http::Response::builder();
|
||||||
|
@ -74,9 +81,9 @@ impl axum_core::response::IntoResponse for RuntimeError {
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
.header("X-Amzn-Errortype", self.kind.name());
|
.header("X-Amzn-Errortype", self.kind.name());
|
||||||
}
|
}
|
||||||
Protocol::RestXml => {
|
Protocol::RestXml => builder = builder.header("Content-Type", "application/xml"),
|
||||||
builder = builder.header("Content-Type", "application/xml");
|
Protocol::AwsJson10 => builder = builder.header("Content-Type", "application/x-amz-json-1.0"),
|
||||||
}
|
Protocol::AwsJson11 => builder = builder.header("Content-Type", "application/x-amz-json-1.1"),
|
||||||
}
|
}
|
||||||
|
|
||||||
builder = builder.extension(crate::extension::RuntimeErrorExtension::new(String::from(
|
builder = builder.extension(crate::extension::RuntimeErrorExtension::new(String::from(
|
||||||
|
|
Loading…
Reference in New Issue