Initial awsJson 1.0 and 1.1 server implementation (#1279)

To support these new protocols, we also changed the runtime Router definition as for both awsJson 1.0 and 1.1, every request MUST be sent to the root URL (/) using the HTTP "POST" method.

The runtime Router now supports explicitly all the available protocols.

There are still some protocol tests failing for awsJson 1.0 and 1.1. The failure are caused by the missing implementation of @endpoint trait and date parsing. Protocol tests for these 2 protocols are heavily biased towards Responses. Before announcing support for awsJson 1.0 and 1.1, we should increase the protocol tests coverage.

Signed-off-by: Bigo <1781140+crisidev@users.noreply.github.com>
Co-authored-by: david-perez <d@vidp.dev>
This commit is contained in:
Matteo Bigoi 2022-04-19 17:41:03 +01:00 committed by GitHub
parent a870d2ad05
commit 2931c9e1e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 826 additions and 198 deletions

View File

@ -36,6 +36,8 @@ val allCodegenTests = listOf(
CodegenTest("com.amazonaws.simple#SimpleService", "simple"), CodegenTest("com.amazonaws.simple#SimpleService", "simple"),
CodegenTest("aws.protocoltests.restjson#RestJson", "rest_json"), CodegenTest("aws.protocoltests.restjson#RestJson", "rest_json"),
CodegenTest("aws.protocoltests.restjson.validation#RestJsonValidation", "rest_json_validation"), CodegenTest("aws.protocoltests.restjson.validation#RestJsonValidation", "rest_json_validation"),
CodegenTest("aws.protocoltests.json10#JsonRpc10", "json_rpc10"),
CodegenTest("aws.protocoltests.json#JsonProtocol", "json_rpc11"),
CodegenTest("aws.protocoltests.misc#MiscService", "misc"), CodegenTest("aws.protocoltests.misc#MiscService", "misc"),
CodegenTest("com.amazonaws.ebs#Ebs", "ebs"), CodegenTest("com.amazonaws.ebs#Ebs", "ebs"),
CodegenTest("com.amazonaws.s3#AmazonS3", "s3"), CodegenTest("com.amazonaws.s3#AmazonS3", "s3"),

View File

@ -38,4 +38,7 @@ object ServerRuntimeType {
fun ResponseRejection(runtimeConfig: RuntimeConfig) = fun ResponseRejection(runtimeConfig: RuntimeConfig) =
RuntimeType("ResponseRejection", ServerCargoDependency.SmithyHttpServer(runtimeConfig), "${runtimeConfig.crateSrcPrefix}_http_server::rejection") RuntimeType("ResponseRejection", ServerCargoDependency.SmithyHttpServer(runtimeConfig), "${runtimeConfig.crateSrcPrefix}_http_server::rejection")
fun Protocol(runtimeConfig: RuntimeConfig) =
RuntimeType("Protocol", ServerCargoDependency.SmithyHttpServer(runtimeConfig), "${runtimeConfig.crateSrcPrefix}_http_server::protocols")
} }

View File

@ -5,6 +5,10 @@
package software.amazon.smithy.rust.codegen.server.smithy.generators package software.amazon.smithy.rust.codegen.server.smithy.generators
import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.rust.codegen.rustlang.Attribute import software.amazon.smithy.rust.codegen.rustlang.Attribute
import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.rustlang.RustWriter
@ -28,7 +32,9 @@ class ServerOperationRegistryGenerator(
private val httpBindingResolver: HttpBindingResolver, private val httpBindingResolver: HttpBindingResolver,
private val operations: List<OperationShape>, private val operations: List<OperationShape>,
) { ) {
private val protocol = codegenContext.protocol
private val symbolProvider = codegenContext.symbolProvider private val symbolProvider = codegenContext.symbolProvider
private val serviceName = codegenContext.serviceShape.toShapeId().name
private val operationNames = operations.map { symbolProvider.toSymbol(it).name.toSnakeCase() } private val operationNames = operations.map { symbolProvider.toSymbol(it).name.toSnakeCase() }
private val runtimeConfig = codegenContext.runtimeConfig private val runtimeConfig = codegenContext.runtimeConfig
private val codegenScope = arrayOf( private val codegenScope = arrayOf(
@ -223,7 +229,7 @@ class ServerOperationRegistryGenerator(
rustTemplate( rustTemplate(
""" """
$requestSpecs $requestSpecs
#{Router}::from_box_clone_service_iter($towerServices) #{Router}::${runtimeRouterConstructor()}($towerServices)
""".trimIndent(), """.trimIndent(),
*codegenScope *codegenScope
) )
@ -241,12 +247,42 @@ class ServerOperationRegistryGenerator(
} }
/* /*
* Generate the `RequestSpec`s for an operation based on its HTTP-bound route. * Finds the runtime function to construct a new `Router` based on the Protocol.
*/ */
private fun OperationShape.requestSpec(): String { private fun runtimeRouterConstructor(): String =
when (protocol) {
RestJson1Trait.ID -> "new_rest_json_router"
RestXmlTrait.ID -> "new_rest_xml_router"
AwsJson1_0Trait.ID -> "new_aws_json_10_router"
AwsJson1_1Trait.ID -> "new_aws_json_11_router"
else -> TODO("Protocol $protocol not supported yet")
}
/*
* Returns the `RequestSpec`s for an operation based on its HTTP-bound route.
*/
private fun OperationShape.requestSpec(): String =
when (protocol) {
RestJson1Trait.ID, RestXmlTrait.ID -> restRequestSpec()
AwsJson1_0Trait.ID, AwsJson1_1Trait.ID -> awsJsonOperationName()
else -> TODO("Protocol $protocol not supported yet")
}
/*
* Returns an AwsJson specific runtime `RequestSpec`.
*/
private fun OperationShape.awsJsonOperationName(): String {
val operationName = symbolProvider.toSymbol(this).name
// TODO(https://github.com/awslabs/smithy-rs/issues/950): Support the `endpoint` trait: https://awslabs.github.io/smithy/1.0/spec/core/endpoint-traits.html#endpoint-trait
return """String::from("$serviceName.$operationName")"""
}
/*
* Generates a REST (RestJson1, RestXml) specific runtime `RequestSpec`.
*/
private fun OperationShape.restRequestSpec(): String {
val httpTrait = httpBindingResolver.httpTrait(this) val httpTrait = httpBindingResolver.httpTrait(this)
val namespace = ServerRuntimeType.RequestSpecModule(runtimeConfig).fullyQualifiedName() val namespace = ServerRuntimeType.RequestSpecModule(runtimeConfig).fullyQualifiedName()
// TODO(https://github.com/awslabs/smithy-rs/issues/950): Support the `endpoint` trait. // TODO(https://github.com/awslabs/smithy-rs/issues/950): Support the `endpoint` trait.
val pathSegments = httpTrait.uri.segments.map { val pathSegments = httpTrait.uri.segments.map {
"$namespace::PathSegment::" + "$namespace::PathSegment::" +
@ -268,7 +304,7 @@ class ServerOperationRegistryGenerator(
$namespace::PathSpec::from_vector_unchecked(vec![${pathSegments.joinToString()}]), $namespace::PathSpec::from_vector_unchecked(vec![${pathSegments.joinToString()}]),
$namespace::QuerySpec::from_vector_unchecked(vec![${querySegments.joinToString()}]) $namespace::QuerySpec::from_vector_unchecked(vec![${querySegments.joinToString()}])
) )
) ),
) )
""".trimIndent() """.trimIndent()
} }

View File

@ -437,8 +437,8 @@ class ServerProtocolTestGenerator(
else -> { else -> {
rustWriter.rustTemplate( rustWriter.rustTemplate(
""" """
#{AssertEq}(parsed.$memberName, expected.$memberName, "Unexpected value for `$memberName`"); #{AssertEq}(parsed.$memberName, expected.$memberName, "Unexpected value for `$memberName`");
""", """,
*codegenScope *codegenScope
) )
} }
@ -798,6 +798,18 @@ class ServerProtocolTestGenerator(
FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostAccelerateAddressing", TestType.Request), FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostAccelerateAddressing", TestType.Request),
FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostDualstackAccelerateAddressing", TestType.Request), FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostDualstackAccelerateAddressing", TestType.Request),
FailingTest("com.amazonaws.s3#AmazonS3", "S3OperationAddressingPreferred", TestType.Request), FailingTest("com.amazonaws.s3#AmazonS3", "S3OperationAddressingPreferred", TestType.Request),
// AwsJson1.0 failing tests.
FailingTest("aws.protocoltests.json10#JsonRpc10", "AwsJson10EndpointTraitWithHostLabel", TestType.Request),
FailingTest("aws.protocoltests.json10#JsonRpc10", "AwsJson10EndpointTrait", TestType.Request),
// AwsJson1.1 failing tests.
FailingTest("aws.protocoltests.json#JsonProtocol", "AwsJson11EndpointTraitWithHostLabel", TestType.Request),
FailingTest("aws.protocoltests.json#JsonProtocol", "AwsJson11EndpointTrait", TestType.Request),
FailingTest("aws.protocoltests.json#JsonProtocol", "parses_httpdate_timestamps", TestType.Response),
FailingTest("aws.protocoltests.json#JsonProtocol", "parses_iso8601_timestamps", TestType.Response),
FailingTest("aws.protocoltests.json#JsonProtocol", "parses_the_request_id_from_the_response", TestType.Response),
) )
private val RunOnly: Set<String>? = null private val RunOnly: Set<String>? = null

View File

@ -0,0 +1,97 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
package software.amazon.smithy.rust.codegen.server.smithy.protocols
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.rust.codegen.rustlang.Writable
import software.amazon.smithy.rust.codegen.rustlang.escape
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.writable
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.smithy.protocols.AwsJson
import software.amazon.smithy.rust.codegen.smithy.protocols.AwsJsonVersion
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingResolver
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolGeneratorFactory
import software.amazon.smithy.rust.codegen.smithy.protocols.awsJsonFieldName
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonCustomization
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSection
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator
import software.amazon.smithy.rust.codegen.util.hasTrait
/*
* AwsJson 1.0 and 1.1 server-side protocol factory. This factory creates the [ServerHttpBoundProtocolGenerator]
* with AwsJson specific configurations.
*/
class ServerAwsJsonFactory(private val version: AwsJsonVersion) : ProtocolGeneratorFactory<ServerHttpBoundProtocolGenerator> {
override fun protocol(codegenContext: CodegenContext): Protocol = ServerAwsJson(codegenContext, version)
override fun buildProtocolGenerator(codegenContext: CodegenContext): ServerHttpBoundProtocolGenerator =
ServerHttpBoundProtocolGenerator(codegenContext, protocol(codegenContext))
override fun transformModel(model: Model): Model = model
override fun support(): ProtocolSupport {
return ProtocolSupport(
/* Client support */
requestSerialization = false,
requestBodySerialization = false,
responseDeserialization = false,
errorDeserialization = false,
/* Server support */
requestDeserialization = true,
requestBodyDeserialization = true,
responseSerialization = true,
errorSerialization = true
)
}
}
/**
* AwsJson requires errors to be serialized with an additional "__type" field. This
* customization writes the right field depending on the version of the AwsJson protocol.
*/
class ServerAwsJsonError(private val awsJsonVersion: AwsJsonVersion) : JsonCustomization() {
override fun section(section: JsonSection): Writable = when (section) {
is JsonSection.ServerError -> writable {
if (section.structureShape.hasTrait<ErrorTrait>()) {
val typeId = when (awsJsonVersion) {
// AwsJson 1.0 wants the whole shape ID (namespace#Shape).
// https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#operation-error-serialization
AwsJsonVersion.Json10 -> section.structureShape.id.toString()
// AwsJson 1.1 wants only the shape name (Shape).
// https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#operation-error-serialization
AwsJsonVersion.Json11 -> section.structureShape.id.name.toString()
}
rust("""${section.jsonObject}.key("__type").string("${escape(typeId)}");""")
}
}
}
}
/**
* AwsJson requires errors to be serialized with an additional "__type" field. This class
* customizes [JsonSerializerGenerator] to add this functionality.
*/
class ServerAwsJsonSerializerGenerator(
private val codegenContext: CodegenContext,
private val httpBindingResolver: HttpBindingResolver,
private val awsJsonVersion: AwsJsonVersion,
private val jsonSerializerGenerator: JsonSerializerGenerator =
JsonSerializerGenerator(codegenContext, httpBindingResolver, ::awsJsonFieldName, customizations = listOf(ServerAwsJsonError(awsJsonVersion)))
) : StructuredDataSerializerGenerator by jsonSerializerGenerator
class ServerAwsJson(
private val codegenContext: CodegenContext,
private val awsJsonVersion: AwsJsonVersion
) : AwsJson(codegenContext, awsJsonVersion) {
override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator =
ServerAwsJsonSerializerGenerator(codegenContext, httpBindingResolver, awsJsonVersion)
}

View File

@ -5,6 +5,8 @@
package software.amazon.smithy.rust.codegen.server.smithy.protocols package software.amazon.smithy.rust.codegen.server.smithy.protocols
import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.codegen.core.Symbol
@ -1064,10 +1066,16 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
private fun getContentTypeCheck(): String { private fun getContentTypeCheck(): String {
when (codegenContext.protocol) { when (codegenContext.protocol) {
RestJson1Trait.ID -> { RestJson1Trait.ID -> {
return "check_json_content_type" return "check_rest_json_1_content_type"
} }
RestXmlTrait.ID -> { RestXmlTrait.ID -> {
return "check_xml_content_type" return "check_rest_xml_content_type"
}
AwsJson1_0Trait.ID -> {
return "check_aws_json_10_content_type"
}
AwsJson1_1Trait.ID -> {
return "check_aws_json_11_content_type"
} }
else -> { else -> {
TODO("Protocol ${codegenContext.protocol} not supported yet") TODO("Protocol ${codegenContext.protocol} not supported yet")
@ -1086,7 +1094,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
return ServerRuntimeType.RequestRejection(runtimeConfig) return ServerRuntimeType.RequestRejection(runtimeConfig)
} }
when (codegenContext.protocol) { when (codegenContext.protocol) {
RestJson1Trait.ID -> { RestJson1Trait.ID, AwsJson1_0Trait.ID, AwsJson1_1Trait.ID -> {
return CargoDependency.smithyJson(runtimeConfig).asType().member("deserialize").member("Error") return CargoDependency.smithyJson(runtimeConfig).asType().member("deserialize").member("Error")
} }
RestXmlTrait.ID -> { RestXmlTrait.ID -> {

View File

@ -5,6 +5,8 @@
package software.amazon.smithy.rust.codegen.server.smithy.protocols package software.amazon.smithy.rust.codegen.server.smithy.protocols
import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.codegen.core.CodegenException
@ -14,6 +16,7 @@ import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.traits.Trait import software.amazon.smithy.model.traits.Trait
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.AwsJsonVersion
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolGeneratorFactory import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolGeneratorFactory
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolMap import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolMap
@ -39,6 +42,8 @@ class ServerProtocolLoader(private val supportedProtocols: ProtocolMap) {
val DefaultProtocols = mapOf( val DefaultProtocols = mapOf(
RestJson1Trait.ID to ServerRestJsonFactory(), RestJson1Trait.ID to ServerRestJsonFactory(),
RestXmlTrait.ID to ServerRestXmlFactory(), RestXmlTrait.ID to ServerRestXmlFactory(),
AwsJson1_0Trait.ID to ServerAwsJsonFactory(AwsJsonVersion.Json10),
AwsJson1_1Trait.ID to ServerAwsJsonFactory(AwsJsonVersion.Json11),
) )
} }
} }

View File

@ -16,9 +16,8 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.RestXml
* RestXml server-side protocol factory. This factory creates the [ServerHttpProtocolGenerator] * RestXml server-side protocol factory. This factory creates the [ServerHttpProtocolGenerator]
* with RestXml specific configurations. * with RestXml specific configurations.
*/ */
class ServerRestXmlFactory(private val generator: (CodegenContext) -> Protocol = { RestXml(it) }) : class ServerRestXmlFactory : ProtocolGeneratorFactory<ServerHttpBoundProtocolGenerator> {
ProtocolGeneratorFactory<ServerHttpBoundProtocolGenerator> { override fun protocol(codegenContext: CodegenContext): Protocol = RestXml(codegenContext)
override fun protocol(codegenContext: CodegenContext): Protocol = generator(codegenContext)
override fun buildProtocolGenerator(codegenContext: CodegenContext): ServerHttpBoundProtocolGenerator = override fun buildProtocolGenerator(codegenContext: CodegenContext): ServerHttpBoundProtocolGenerator =
ServerHttpBoundProtocolGenerator(codegenContext, RestXml(codegenContext)) ServerHttpBoundProtocolGenerator(codegenContext, RestXml(codegenContext))

View File

@ -128,7 +128,7 @@ class AwsJsonSerializerGenerator(
} }
} }
class AwsJson( open class AwsJson(
private val codegenContext: CodegenContext, private val codegenContext: CodegenContext,
awsJsonVersion: AwsJsonVersion awsJsonVersion: AwsJsonVersion
) : Protocol { ) : Protocol {
@ -183,6 +183,4 @@ class AwsJson(
} }
} }
private fun awsJsonFieldName(member: MemberShape): String { fun awsJsonFieldName(member: MemberShape): String = member.memberName
return member.memberName
}

View File

@ -34,6 +34,8 @@ import software.amazon.smithy.rust.codegen.rustlang.withBlock
import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.customize.NamedSectionGenerator
import software.amazon.smithy.rust.codegen.smithy.customize.Section
import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.renderUnknownVariant import software.amazon.smithy.rust.codegen.smithy.generators.renderUnknownVariant
import software.amazon.smithy.rust.codegen.smithy.generators.serializationError import software.amazon.smithy.rust.codegen.smithy.generators.serializationError
@ -49,11 +51,25 @@ import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.inputShape import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.outputShape import software.amazon.smithy.rust.codegen.util.outputShape
/**
* Class describing a JSON section that can be used in a customization.
*/
sealed class JsonSection(name: String) : Section(name) {
/** Mutate the server error object prior to finalization. Eg: this can be used to inject `__type` to record the error type. */
data class ServerError(val structureShape: StructureShape, val jsonObject: String) : JsonSection("ServerError")
}
/**
* JSON customization.
*/
typealias JsonCustomization = NamedSectionGenerator<JsonSection>
class JsonSerializerGenerator( class JsonSerializerGenerator(
codegenContext: CodegenContext, codegenContext: CodegenContext,
private val httpBindingResolver: HttpBindingResolver, private val httpBindingResolver: HttpBindingResolver,
/** Function that maps a MemberShape into a JSON field name */ /** Function that maps a MemberShape into a JSON field name */
private val jsonName: (MemberShape) -> String, private val jsonName: (MemberShape) -> String,
private val customizations: List<JsonCustomization> = listOf(),
) : StructuredDataSerializerGenerator { ) : StructuredDataSerializerGenerator {
private data class Context<T : Shape>( private data class Context<T : Shape>(
/** Expression that retrieves a JsonValueWriter from either a JsonObjectWriter or JsonArrayWriter */ /** Expression that retrieves a JsonValueWriter from either a JsonObjectWriter or JsonArrayWriter */
@ -155,7 +171,7 @@ class JsonSerializerGenerator(
private fun serverStructureSerializer( private fun serverStructureSerializer(
fnName: String, fnName: String,
structureShape: StructureShape, structureShape: StructureShape,
includedMembers: List<MemberShape> includedMembers: List<MemberShape>,
): RuntimeType { ): RuntimeType {
return RuntimeType.forInlineFun(fnName, operationSerModule) { return RuntimeType.forInlineFun(fnName, operationSerModule) {
it.rustBlockTemplate( it.rustBlockTemplate(
@ -166,6 +182,7 @@ class JsonSerializerGenerator(
rust("let mut out = String::new();") rust("let mut out = String::new();")
rustTemplate("let mut object = #{JsonObjectWriter}::new(&mut out);", *codegenScope) rustTemplate("let mut object = #{JsonObjectWriter}::new(&mut out);", *codegenScope)
serializeStructure(StructContext("object", "value", structureShape), includedMembers) serializeStructure(StructContext("object", "value", structureShape), includedMembers)
customizations.forEach { it.section(JsonSection.ServerError(structureShape, "object"))(this) }
rust("object.finish();") rust("object.finish();")
rustTemplate("Ok(out)", *codegenScope) rustTemplate("Ok(out)", *codegenScope)
} }

View File

@ -29,6 +29,7 @@ http-body = "0.4"
hyper = { version = "0.14", features = ["server", "http1", "http2", "tcp", "stream"] } hyper = { version = "0.14", features = ["server", "http1", "http2", "tcp", "stream"] }
mime = "0.3" mime = "0.3"
nom = "7" nom = "7"
paste = "1"
pin-project-lite = "0.2" pin-project-lite = "0.2"
regex = "1.0" regex = "1.0"
serde_urlencoded = "0.7" serde_urlencoded = "0.7"

View File

@ -6,51 +6,206 @@
//! Protocol helpers. //! Protocol helpers.
use crate::rejection::RequestRejection; use crate::rejection::RequestRejection;
use axum_core::extract::RequestParts; use axum_core::extract::RequestParts;
use paste::paste;
#[derive(Debug)] /// Supported protocols.
#[derive(Debug, Clone, Copy)]
pub enum Protocol { pub enum Protocol {
RestJson1, RestJson1,
RestXml, RestXml,
AwsJson10,
AwsJson11,
} }
/// Validate that the request had the standard JSON content-type header. /// Implement the content-type header validation for a request.
pub fn check_json_content_type<B>(req: &RequestParts<B>) -> Result<(), RequestRejection> { macro_rules! impl_content_type_validation {
let mime = req ($name:literal, $type: literal, $subtype:literal, $rejection:path) => {
.headers() paste! {
.ok_or(RequestRejection::MissingJsonContentType)? #[doc = concat!("Validates that the request has the standard `", $type, "/", $subtype, "` content-type header.")]
.get(http::header::CONTENT_TYPE) pub fn [<check_ $name _content_type>]<B>(req: &RequestParts<B>) -> Result<(), RequestRejection> {
.ok_or(RequestRejection::MissingJsonContentType)? let mime = req
.to_str() .headers()
.map_err(|_| RequestRejection::MissingJsonContentType)? .ok_or($rejection)?
.parse::<mime::Mime>() .get(http::header::CONTENT_TYPE)
.map_err(|_| RequestRejection::MimeParse)?; .ok_or($rejection)?
.to_str()
.map_err(|_| $rejection)?
.parse::<mime::Mime>()
.map_err(|_| RequestRejection::MimeParse)?;
if mime.type_() == $type && mime.subtype() == $subtype {
Ok(())
} else {
Err($rejection)
}
}
}
};
}
if mime.type_() == "application" impl_content_type_validation!(
&& (mime.subtype() == "json" || mime.suffix().filter(|name| *name == "json").is_some()) "rest_json_1",
{ "application",
Ok(()) "json",
} else { RequestRejection::MissingRestJson1ContentType
Err(RequestRejection::MissingJsonContentType) );
}
} impl_content_type_validation!(
"rest_xml",
/// Validate that the request had the standard XML content-type header. "application",
pub fn check_xml_content_type<B>(req: &RequestParts<B>) -> Result<(), RequestRejection> { "xml",
let mime = req RequestRejection::MissingRestXmlContentType
.headers() );
.ok_or(RequestRejection::MissingXmlContentType)?
.get(http::header::CONTENT_TYPE) impl_content_type_validation!(
.ok_or(RequestRejection::MissingXmlContentType)? "aws_json_10",
.to_str() "application",
.map_err(|_| RequestRejection::MissingXmlContentType)? "x-amz-json-1.0",
.parse::<mime::Mime>() RequestRejection::MissingAwsJson10ContentType
.map_err(|_| RequestRejection::MimeParse)?; );
if mime.type_() == "application" impl_content_type_validation!(
&& (mime.subtype() == "xml" || mime.suffix().filter(|name| *name == "xml").is_some()) "aws_json_11",
{ "application",
Ok(()) "x-amz-json-1.1",
} else { RequestRejection::MissingAwsJson11ContentType
Err(RequestRejection::MissingXmlContentType) );
#[cfg(test)]
mod tests {
use super::*;
use http::Request;
fn req(content_type: &str) -> RequestParts<&str> {
let request = Request::builder()
.header("content-type", content_type)
.body("")
.unwrap();
RequestParts::new(request)
}
/// This macro validates the rejection type since we cannot implement `PartialEq`
/// for `RequestRejection` as it is based on the crate error type, which uses
/// `axum_core::BoxError`.
macro_rules! validate_rejection_type {
($result:expr, $rejection:path) => {
match $result {
Ok(()) => panic!("Content-type validation is expected to fail"),
Err(e) => match e {
$rejection => {}
_ => panic!("Error {} should be {}", e.to_string(), stringify!($rejection)),
},
}
};
}
#[test]
fn validate_rest_json_1_content_type() {
// Check valid content-type header.
let request = req("application/json");
assert!(check_rest_json_1_content_type(&request).is_ok());
// Check invalid content-type header.
let invalid = vec![
req("application/ajson"),
req("application/json1"),
req("applicatio/json"),
req("application/xml"),
req("text/xml"),
req("application/x-amz-json-1.0"),
req("application/x-amz-json-1.1"),
RequestParts::new(Request::builder().body("").unwrap()),
];
for request in &invalid {
validate_rejection_type!(
check_rest_json_1_content_type(request),
RequestRejection::MissingRestJson1ContentType
);
}
// Check request with not parsable content-type header.
validate_rejection_type!(check_rest_json_1_content_type(&req("123")), RequestRejection::MimeParse);
}
#[test]
fn validate_rest_xml_content_type() {
// Check valid content-type header.
let request = req("application/xml");
assert!(check_rest_xml_content_type(&request).is_ok());
// Check invalid content-type header.
let invalid = vec![
req("application/axml"),
req("application/xml1"),
req("applicatio/xml"),
req("text/xml"),
req("application/x-amz-json-1.0"),
req("application/x-amz-json-1.1"),
RequestParts::new(Request::builder().body("").unwrap()),
];
for request in &invalid {
validate_rejection_type!(
check_rest_xml_content_type(request),
RequestRejection::MissingRestXmlContentType
);
}
// Check request with not parsable content-type header.
validate_rejection_type!(check_rest_xml_content_type(&req("123")), RequestRejection::MimeParse);
}
#[test]
fn validate_aws_json_10_content_type() {
// Check valid content-type header.
let request = req("application/x-amz-json-1.0");
assert!(check_aws_json_10_content_type(&request).is_ok());
// Check invalid content-type header.
let invalid = vec![
req("application/x-amz-json-1."),
req("application/-amz-json-1.0"),
req("application/xml"),
req("application/json"),
req("applicatio/x-amz-json-1.0"),
req("text/xml"),
req("application/x-amz-json-1.1"),
RequestParts::new(Request::builder().body("").unwrap()),
];
for request in &invalid {
validate_rejection_type!(
check_aws_json_10_content_type(request),
RequestRejection::MissingAwsJson10ContentType
);
}
// Check request with not parsable content-type header.
validate_rejection_type!(check_aws_json_10_content_type(&req("123")), RequestRejection::MimeParse);
}
#[test]
fn validate_aws_json_11_content_type() {
// Check valid content-type header.
let request = req("application/x-amz-json-1.1");
assert!(check_aws_json_11_content_type(&request).is_ok());
// Check invalid content-type header.
let invalid = vec![
req("application/x-amz-json-1."),
req("application/-amz-json-1.1"),
req("application/xml"),
req("application/json"),
req("applicatio/x-amz-json-1.1"),
req("text/xml"),
req("application/x-amz-json-1.0"),
RequestParts::new(Request::builder().body("").unwrap()),
];
for request in &invalid {
validate_rejection_type!(
check_aws_json_11_content_type(request),
RequestRejection::MissingAwsJson11ContentType
);
}
// Check request with not parsable content-type header.
validate_rejection_type!(check_aws_json_11_content_type(&req("123")), RequestRejection::MimeParse);
} }
} }

View File

@ -134,8 +134,10 @@ pub enum RequestRejection {
HttpBody(crate::Error), HttpBody(crate::Error),
// These are used when checking the `Content-Type` header. // These are used when checking the `Content-Type` header.
MissingJsonContentType, MissingRestJson1ContentType,
MissingXmlContentType, MissingAwsJson10ContentType,
MissingAwsJson11ContentType,
MissingRestXmlContentType,
MimeParse, MimeParse,
/// Used when failing to deserialize the HTTP body's bytes into a JSON document conforming to /// Used when failing to deserialize the HTTP body's bytes into a JSON document conforming to

View File

@ -7,10 +7,15 @@
//! //!
//! [Smithy specification]: https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html //! [Smithy specification]: https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html
use self::{future::RouterFuture, request_spec::RequestSpec}; use self::future::RouterFuture;
use self::request_spec::RequestSpec;
use crate::body::{boxed, Body, BoxBody, HttpBody}; use crate::body::{boxed, Body, BoxBody, HttpBody};
use crate::protocols::Protocol;
use crate::runtime_error::{RuntimeError, RuntimeErrorKind};
use crate::BoxError; use crate::BoxError;
use axum_core::response::IntoResponse;
use http::{Request, Response, StatusCode}; use http::{Request, Response, StatusCode};
use std::collections::HashMap;
use std::{ use std::{
convert::Infallible, convert::Infallible,
task::{Context, Poll}, task::{Context, Poll},
@ -31,34 +36,60 @@ mod route;
pub use self::{into_make_service::IntoMakeService, route::Route}; pub use self::{into_make_service::IntoMakeService, route::Route};
/// The router is a [`tower::Service`] that routes incoming requests to other `Service`s /// The router is a [`tower::Service`] that routes incoming requests to other `Service`s
/// based on the request's URI and HTTP method, adhering to the [Smithy specification]. /// based on the request's URI and HTTP method or on some specific header setting the target operation.
/// The former is adhering to the [Smithy specification], while the latter is adhering to
/// the [AwsJson specification].
///
/// The router is also [Protocol] aware and currently supports REST based protocols like [restJson1] or [restXml]
/// and RPC based protocols like [awsJson1.0] or [awsJson1.1].
/// It currently does not support Smithy's [endpoint trait]. /// It currently does not support Smithy's [endpoint trait].
/// ///
/// You should not **instantiate** this router directly; it will be created for you from the /// You should not **instantiate** this router directly; it will be created for you from the
/// code generated from your Smithy model by `smithy-rs`. /// code generated from your Smithy model by `smithy-rs`.
/// ///
/// [Smithy specification]: https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html /// [Smithy specification]: https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html
/// [AwsJson specification]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#protocol-behaviors
/// [Protocol]: https://awslabs.github.io/smithy/1.0/spec/aws/index.html#aws-protocols
/// [restJson1]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-restjson1-protocol.html
/// [restXml]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-restxml-protocol.html
/// [awsJson1.0]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html
/// [awsJson1.1]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html
/// [endpoint trait]: https://awslabs.github.io/smithy/1.0/spec/core/endpoint-traits.html#endpoint-trait /// [endpoint trait]: https://awslabs.github.io/smithy/1.0/spec/core/endpoint-traits.html#endpoint-trait
#[derive(Debug)] #[derive(Debug)]
pub struct Router<B = Body> { pub struct Router<B = Body> {
routes: Vec<(Route<B>, RequestSpec)>, routes: Routes<B>,
}
/// Protocol-aware routes types.
///
/// RestJson1 and RestXml routes are stored in a `Vec` because there can be multiple matches on the
/// request URI and we thus need to iterate the whole list and use a ranking mechanism to choose.
///
/// AwsJson 1.0 and 1.1 routes can be stored in a `HashMap` since the requested operation can be
/// directly found in the `X-Amz-Target` HTTP header.
#[derive(Debug)]
enum Routes<B = Body> {
RestXml(Vec<(Route<B>, RequestSpec)>),
RestJson1(Vec<(Route<B>, RequestSpec)>),
AwsJson10(HashMap<String, Route<B>>),
AwsJson11(HashMap<String, Route<B>>),
} }
impl<B> Clone for Router<B> { impl<B> Clone for Router<B> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { match &self.routes {
routes: self.routes.clone(), Routes::RestJson1(routes) => Router {
} routes: Routes::RestJson1(routes.clone()),
} },
} Routes::RestXml(routes) => Router {
routes: Routes::RestXml(routes.clone()),
impl<B> Default for Router<B> },
where Routes::AwsJson10(routes) => Router {
B: Send + 'static, routes: Routes::AwsJson10(routes.clone()),
{ },
fn default() -> Self { Routes::AwsJson11(routes) => Router {
Self { routes: Routes::AwsJson11(routes.clone()),
routes: Default::default(), },
} }
} }
} }
@ -67,32 +98,29 @@ impl<B> Router<B>
where where
B: Send + 'static, B: Send + 'static,
{ {
/// Create a new `Router` from a vector of pairs of request specs and services. /// Return the correct, protocol-specific "Not Found" response for an unknown operation.
/// fn unknown_operation(&self) -> RouterFuture<B> {
/// If the vector is empty the router will respond `404 Not Found` to all requests. let protocol = match &self.routes {
#[doc(hidden)] Routes::RestJson1(_) => Protocol::RestJson1,
pub fn from_box_clone_service_iter<T>(routes: T) -> Self Routes::RestXml(_) => Protocol::RestXml,
where Routes::AwsJson10(_) => Protocol::AwsJson10,
T: IntoIterator< Routes::AwsJson11(_) => Protocol::AwsJson11,
Item = ( };
tower::util::BoxCloneService<Request<B>, Response<BoxBody>, Infallible>, let error = RuntimeError {
RequestSpec, protocol,
), kind: RuntimeErrorKind::UnknownOperation,
>, };
{ RouterFuture::from_response(error.into_response())
let mut routes: Vec<(Route<B>, RequestSpec)> = routes
.into_iter()
.map(|(svc, request_spec)| (Route::from_box_clone_service(svc), request_spec))
.collect();
// Sort them once by specifity, with the more specific routes sorted before the less
// specific ones, so that when routing a request we can simply iterate through the routes
// and pick the first one that matches.
routes.sort_by_key(|(_route, request_spec)| std::cmp::Reverse(request_spec.rank()));
Self { routes }
} }
/// Return the HTTP error response for non allowed method.
fn method_not_allowed(&self) -> RouterFuture<B> {
RouterFuture::from_response({
let mut res = Response::new(crate::body::empty());
*res.status_mut() = StatusCode::METHOD_NOT_ALLOWED;
res
})
}
/// Convert this router into a [`MakeService`], that is a [`Service`] whose /// Convert this router into a [`MakeService`], that is a [`Service`] whose
/// response is another service. /// response is another service.
/// ///
@ -124,12 +152,146 @@ where
.layer_fn(Route::new) .layer_fn(Route::new)
.layer(MapResponseBodyLayer::new(boxed)) .layer(MapResponseBodyLayer::new(boxed))
.layer(layer); .layer(layer);
let routes = self match self.routes {
.routes Routes::RestJson1(routes) => {
let routes = routes
.into_iter()
.map(|(route, request_spec)| (Layer::layer(&layer, route), request_spec))
.collect();
Router {
routes: Routes::RestJson1(routes),
}
}
Routes::RestXml(routes) => {
let routes = routes
.into_iter()
.map(|(route, request_spec)| (Layer::layer(&layer, route), request_spec))
.collect();
Router {
routes: Routes::RestXml(routes),
}
}
Routes::AwsJson10(routes) => {
let routes = routes
.into_iter()
.map(|(operation, route)| (operation, Layer::layer(&layer, route)))
.collect();
Router {
routes: Routes::AwsJson10(routes),
}
}
Routes::AwsJson11(routes) => {
let routes = routes
.into_iter()
.map(|(operation, route)| (operation, Layer::layer(&layer, route)))
.collect();
Router {
routes: Routes::AwsJson11(routes),
}
}
}
}
/// Create a new RestJson1 `Router` from an iterator over pairs of [`RequestSpec`]s and services.
///
/// If the iterator is empty the router will respond `404 Not Found` to all requests.
#[doc(hidden)]
pub fn new_rest_json_router<T>(routes: T) -> Self
where
T: IntoIterator<
Item = (
tower::util::BoxCloneService<Request<B>, Response<BoxBody>, Infallible>,
RequestSpec,
),
>,
{
let mut routes: Vec<(Route<B>, RequestSpec)> = routes
.into_iter() .into_iter()
.map(|(route, request_spec)| (Layer::layer(&layer, route), request_spec)) .map(|(svc, request_spec)| (Route::from_box_clone_service(svc), request_spec))
.collect(); .collect();
Router { routes }
// Sort them once by specifity, with the more specific routes sorted before the less
// specific ones, so that when routing a request we can simply iterate through the routes
// and pick the first one that matches.
routes.sort_by_key(|(_route, request_spec)| std::cmp::Reverse(request_spec.rank()));
Self {
routes: Routes::RestJson1(routes),
}
}
/// Create a new RestXml `Router` from an iterator over pairs of [`RequestSpec`]s and services.
///
/// If the iterator is empty the router will respond `404 Not Found` to all requests.
#[doc(hidden)]
pub fn new_rest_xml_router<T>(routes: T) -> Self
where
T: IntoIterator<
Item = (
tower::util::BoxCloneService<Request<B>, Response<BoxBody>, Infallible>,
RequestSpec,
),
>,
{
let mut routes: Vec<(Route<B>, RequestSpec)> = routes
.into_iter()
.map(|(svc, request_spec)| (Route::from_box_clone_service(svc), request_spec))
.collect();
// Sort them once by specifity, with the more specific routes sorted before the less
// specific ones, so that when routing a request we can simply iterate through the routes
// and pick the first one that matches.
routes.sort_by_key(|(_route, request_spec)| std::cmp::Reverse(request_spec.rank()));
Self {
routes: Routes::RestXml(routes),
}
}
/// Create a new AwsJson 1.0 `Router` from an iterator over pairs of operation names and services.
///
/// If the iterator is empty the router will respond `404 Not Found` to all requests.
#[doc(hidden)]
pub fn new_aws_json_10_router<T>(routes: T) -> Self
where
T: IntoIterator<
Item = (
tower::util::BoxCloneService<Request<B>, Response<BoxBody>, Infallible>,
String,
),
>,
{
let routes = routes
.into_iter()
.map(|(svc, operation)| (operation, Route::from_box_clone_service(svc)))
.collect();
Self {
routes: Routes::AwsJson10(routes),
}
}
/// Create a new AwsJson 1.1 `Router` from a vector of pairs of operations and services.
///
/// If the vector is empty the router will respond `404 Not Found` to all requests.
#[doc(hidden)]
pub fn new_aws_json_11_router<T>(routes: T) -> Self
where
T: IntoIterator<
Item = (
tower::util::BoxCloneService<Request<B>, Response<BoxBody>, Infallible>,
String,
),
>,
{
let routes = routes
.into_iter()
.map(|(svc, operation)| (operation, Route::from_box_clone_service(svc)))
.collect();
Self {
routes: Routes::AwsJson11(routes),
}
} }
} }
@ -148,44 +310,84 @@ where
#[inline] #[inline]
fn call(&mut self, req: Request<B>) -> Self::Future { fn call(&mut self, req: Request<B>) -> Self::Future {
let mut method_not_allowed = false; match &self.routes {
// REST routes.
Routes::RestJson1(routes) | Routes::RestXml(routes) => {
let mut method_not_allowed = false;
for (route, request_spec) in &self.routes { // Loop through all the routes and validate if any of them matches. Routes are already ranked.
match request_spec.matches(&req) { for (route, request_spec) in routes {
request_spec::Match::Yes => { match request_spec.matches(&req) {
return RouterFuture::from_oneshot(route.clone().oneshot(req)); request_spec::Match::Yes => {
return RouterFuture::from_oneshot(route.clone().oneshot(req));
}
request_spec::Match::MethodNotAllowed => method_not_allowed = true,
// Continue looping to see if another route matches.
request_spec::Match::No => continue,
}
} }
request_spec::Match::MethodNotAllowed => method_not_allowed = true,
// Continue looping to see if another route matches. if method_not_allowed {
request_spec::Match::No => continue, // The HTTP method is not correct.
self.method_not_allowed()
} else {
// In any other case return the `RuntimeError::UnknownOperation`.
self.unknown_operation()
}
}
// AwsJson routes.
Routes::AwsJson10(routes) | Routes::AwsJson11(routes) => {
if req.uri() == "/" {
// Check the request method for POST.
if req.method() == http::Method::POST {
// Find the `x-amz-target` header.
if let Some(target) = req.headers().get("x-amz-target") {
if let Ok(target) = target.to_str() {
// Lookup in the `HashMap` for a route for the target.
let route = routes.get(target);
if let Some(route) = route {
return RouterFuture::from_oneshot(route.clone().oneshot(req));
}
}
}
} else {
// The HTTP method is not POST.
return self.method_not_allowed();
}
}
// In any other case return the `RuntimeError::UnknownOperation`.
self.unknown_operation()
} }
} }
let status_code = if method_not_allowed {
StatusCode::METHOD_NOT_ALLOWED
} else {
StatusCode::NOT_FOUND
};
RouterFuture::from_response(
Response::builder()
.status(status_code)
.body(crate::body::empty())
.unwrap(),
)
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod rest_tests {
use super::*; use super::*;
use crate::{body::boxed, routing::request_spec::*}; use crate::{body::boxed, routing::request_spec::*};
use futures_util::Future; use futures_util::Future;
use http::Method; use http::{HeaderMap, Method};
use std::pin::Pin; use std::pin::Pin;
/// Helper function to build a `Request`. Used in other test modules. /// Helper function to build a `Request`. Used in other test modules.
pub fn req(method: &Method, uri: &str) -> Request<()> { pub fn req(method: &Method, uri: &str, headers: Option<HeaderMap>) -> Request<()> {
Request::builder().method(method).uri(uri).body(()).unwrap() let mut r = Request::builder().method(method).uri(uri).body(()).unwrap();
if let Some(headers) = headers {
*r.headers_mut() = headers
}
r
}
// Returns a `Response`'s body as a `String`, without consuming the response.
pub async fn get_body_as_string<B>(res: &mut Response<B>) -> String
where
B: http_body::Body + std::marker::Unpin,
B::Error: std::fmt::Debug,
{
let body_mut = res.body_mut();
let body_bytes = hyper::body::to_bytes(body_mut).await.unwrap();
String::from(std::str::from_utf8(&body_bytes).unwrap())
} }
/// A service that returns its name and the request's URI in the response body. /// A service that returns its name and the request's URI in the response body.
@ -210,17 +412,6 @@ mod tests {
} }
} }
// Returns a `Response`'s body as a `String`, without consuming the response.
async fn get_body_as_str<B>(res: &mut Response<B>) -> String
where
B: http_body::Body + std::marker::Unpin,
B::Error: std::fmt::Debug,
{
let body_mut = res.body_mut();
let body_bytes = hyper::body::to_bytes(body_mut).await.unwrap();
String::from(std::str::from_utf8(&body_bytes).unwrap())
}
// This test is a rewrite of `mux.spec.ts`. // This test is a rewrite of `mux.spec.ts`.
// https://github.com/awslabs/smithy-typescript/blob/fbf97a9bf4c1d8cf7f285ea7c24e1f0ef280142a/smithy-typescript-ssdk-libs/server-common/src/httpbinding/mux.spec.ts // https://github.com/awslabs/smithy-typescript/blob/fbf97a9bf4c1d8cf7f285ea7c24e1f0ef280142a/smithy-typescript-ssdk-libs/server-common/src/httpbinding/mux.spec.ts
#[tokio::test] #[tokio::test]
@ -271,55 +462,64 @@ mod tests {
), ),
]; ];
let mut router = Router::from_box_clone_service_iter(request_specs.into_iter().map(|(spec, svc_name)| { // Test both RestJson1 and RestXml routers.
let router_json = Router::new_rest_json_router(request_specs.clone().into_iter().map(|(spec, svc_name)| {
(
tower::util::BoxCloneService::new(NamedEchoUriService(String::from(svc_name))),
spec,
)
}));
let router_xml = Router::new_rest_xml_router(request_specs.into_iter().map(|(spec, svc_name)| {
( (
tower::util::BoxCloneService::new(NamedEchoUriService(String::from(svc_name))), tower::util::BoxCloneService::new(NamedEchoUriService(String::from(svc_name))),
spec, spec,
) )
})); }));
let hits = vec![ for mut router in [router_json, router_xml] {
("A", Method::GET, "/a/b/c"), let hits = vec![
("MiddleGreedy", Method::GET, "/mg/a/z"), ("A", Method::GET, "/a/b/c"),
("MiddleGreedy", Method::GET, "/mg/a/b/c/d/z?abc=def"), ("MiddleGreedy", Method::GET, "/mg/a/z"),
("Delete", Method::DELETE, "/?foo=bar&baz=quux"), ("MiddleGreedy", Method::GET, "/mg/a/b/c/d/z?abc=def"),
("Delete", Method::DELETE, "/?foo=bar&baz"), ("Delete", Method::DELETE, "/?foo=bar&baz=quux"),
("Delete", Method::DELETE, "/?foo=bar&baz=&"), ("Delete", Method::DELETE, "/?foo=bar&baz"),
("Delete", Method::DELETE, "/?foo=bar&baz=quux&baz=grault"), ("Delete", Method::DELETE, "/?foo=bar&baz=&"),
("QueryKeyOnly", Method::POST, "/query_key_only?foo=bar"), ("Delete", Method::DELETE, "/?foo=bar&baz=quux&baz=grault"),
("QueryKeyOnly", Method::POST, "/query_key_only?foo"), ("QueryKeyOnly", Method::POST, "/query_key_only?foo=bar"),
("QueryKeyOnly", Method::POST, "/query_key_only?foo="), ("QueryKeyOnly", Method::POST, "/query_key_only?foo"),
("QueryKeyOnly", Method::POST, "/query_key_only?foo=&"), ("QueryKeyOnly", Method::POST, "/query_key_only?foo="),
]; ("QueryKeyOnly", Method::POST, "/query_key_only?foo=&"),
for (svc_name, method, uri) in &hits { ];
let mut res = router.call(req(method, uri)).await.unwrap(); for (svc_name, method, uri) in &hits {
let actual_body = get_body_as_str(&mut res).await; let mut res = router.call(req(method, uri, None)).await.unwrap();
let actual_body = get_body_as_string(&mut res).await;
assert_eq!(format!("{} :: {}", svc_name, uri), actual_body); assert_eq!(format!("{} :: {}", svc_name, uri), actual_body);
} }
for (_, _, uri) in hits { for (_, _, uri) in hits {
let res = router.call(req(&Method::PATCH, uri)).await.unwrap(); let res = router.call(req(&Method::PATCH, uri, None)).await.unwrap();
assert_eq!(StatusCode::METHOD_NOT_ALLOWED, res.status()); assert_eq!(StatusCode::METHOD_NOT_ALLOWED, res.status());
} }
let misses = vec![ let misses = vec![
(Method::GET, "/a"), (Method::GET, "/a"),
(Method::GET, "/a/b"), (Method::GET, "/a/b"),
(Method::GET, "/mg"), (Method::GET, "/mg"),
(Method::GET, "/mg/q"), (Method::GET, "/mg/q"),
(Method::GET, "/mg/z"), (Method::GET, "/mg/z"),
(Method::GET, "/mg/a/b/z/c"), (Method::GET, "/mg/a/b/z/c"),
(Method::DELETE, "/?foo=bar"), (Method::DELETE, "/?foo=bar"),
(Method::DELETE, "/?foo=bar"), (Method::DELETE, "/?foo=bar"),
(Method::DELETE, "/?baz=quux"), (Method::DELETE, "/?baz=quux"),
(Method::POST, "/query_key_only?baz=quux"), (Method::POST, "/query_key_only?baz=quux"),
(Method::GET, "/"), (Method::GET, "/"),
(Method::POST, "/"), (Method::POST, "/"),
]; ];
for (method, miss) in misses { for (method, miss) in misses {
let res = router.call(req(&method, miss)).await.unwrap(); let res = router.call(req(&method, miss, None)).await.unwrap();
assert_eq!(StatusCode::NOT_FOUND, res.status()); assert_eq!(StatusCode::NOT_FOUND, res.status());
}
} }
} }
@ -364,7 +564,7 @@ mod tests {
), ),
]; ];
let mut router = Router::from_box_clone_service_iter(request_specs.into_iter().map(|(spec, svc_name)| { let mut router = Router::new_rest_json_router(request_specs.into_iter().map(|(spec, svc_name)| {
( (
tower::util::BoxCloneService::new(NamedEchoUriService(String::from(svc_name))), tower::util::BoxCloneService::new(NamedEchoUriService(String::from(svc_name))),
spec, spec,
@ -378,10 +578,96 @@ mod tests {
("B2", Method::GET, "/b/foo?q=baz"), ("B2", Method::GET, "/b/foo?q=baz"),
]; ];
for (svc_name, method, uri) in &hits { for (svc_name, method, uri) in &hits {
let mut res = router.call(req(method, uri)).await.unwrap(); let mut res = router.call(req(method, uri, None)).await.unwrap();
let actual_body = get_body_as_str(&mut res).await; let actual_body = get_body_as_string(&mut res).await;
assert_eq!(format!("{} :: {}", svc_name, uri), actual_body); assert_eq!(format!("{} :: {}", svc_name, uri), actual_body);
} }
} }
} }
#[cfg(test)]
mod awsjson_tests {
use super::rest_tests::{get_body_as_string, req};
use super::*;
use crate::body::boxed;
use futures_util::Future;
use http::{HeaderMap, HeaderValue, Method};
use pretty_assertions::assert_eq;
use std::pin::Pin;
/// A service that returns its name and the request's URI in the response body.
#[derive(Clone)]
struct NamedEchoOperationService(String);
impl<B> Service<Request<B>> for NamedEchoOperationService {
type Response = Response<BoxBody>;
type Error = Infallible;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
#[inline]
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
#[inline]
fn call(&mut self, req: Request<B>) -> Self::Future {
let target = req
.headers()
.get("x-amz-target")
.map(|x| x.to_str().unwrap())
.unwrap_or("unknown");
let body = boxed(Body::from(format!("{} :: {}", self.0, target)));
let fut = async { Ok(Response::builder().status(&http::StatusCode::OK).body(body).unwrap()) };
Box::pin(fut)
}
}
#[tokio::test]
async fn simple_routing() {
let routes = vec![("Service.Operation", "A")];
let router_json10 = Router::new_aws_json_10_router(routes.clone().into_iter().map(|(operation, svc_name)| {
(
tower::util::BoxCloneService::new(NamedEchoOperationService(String::from(svc_name))),
operation.to_string(),
)
}));
let router_json11 = Router::new_aws_json_11_router(routes.into_iter().map(|(operation, svc_name)| {
(
tower::util::BoxCloneService::new(NamedEchoOperationService(String::from(svc_name))),
operation.to_string(),
)
}));
for mut router in [router_json10, router_json11] {
let mut headers = HeaderMap::new();
headers.insert("x-amz-target", HeaderValue::from_static("Service.Operation"));
// Valid request, should return a valid body.
let mut res = router
.call(req(&Method::POST, "/", Some(headers.clone())))
.await
.unwrap();
let actual_body = get_body_as_string(&mut res).await;
assert_eq!(format!("{} :: {}", "A", "Service.Operation"), actual_body);
// No headers, should return NOT_FOUND.
let res = router.call(req(&Method::POST, "/", None)).await.unwrap();
assert_eq!(res.status(), StatusCode::NOT_FOUND);
// Wrong HTTP method, should return METHOD_NOT_ALLOWED.
let res = router
.call(req(&Method::GET, "/", Some(headers.clone())))
.await
.unwrap();
assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED);
// Wrong URI, should return NOT_FOUND.
let res = router
.call(req(&Method::POST, "/something", Some(headers)))
.await
.unwrap();
assert_eq!(res.status(), StatusCode::NOT_FOUND);
}
}
}

View File

@ -243,7 +243,7 @@ impl RequestSpec {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::super::tests::req; use super::super::rest_tests::req;
use super::*; use super::*;
use http::Method; use http::Method;
@ -295,7 +295,7 @@ mod tests {
(Method::GET, "/mg/a/z/z/z"), (Method::GET, "/mg/a/z/z/z"),
]; ];
for (method, uri) in &hits { for (method, uri) in &hits {
assert_eq!(Match::Yes, spec.matches(&req(method, uri))); assert_eq!(Match::Yes, spec.matches(&req(method, uri, None)));
} }
} }
@ -309,7 +309,7 @@ mod tests {
(Method::DELETE, "/?foo&foo"), (Method::DELETE, "/?foo&foo"),
]; ];
for (method, uri) in &hits { for (method, uri) in &hits {
assert_eq!(Match::Yes, spec.matches(&req(method, uri))); assert_eq!(Match::Yes, spec.matches(&req(method, uri, None)));
} }
} }
@ -325,7 +325,7 @@ mod tests {
fn repeated_query_keys_same_values_match() { fn repeated_query_keys_same_values_match() {
assert_eq!( assert_eq!(
Match::Yes, Match::Yes,
key_value_spec().matches(&req(&Method::DELETE, "/?foo=bar&foo=bar")) key_value_spec().matches(&req(&Method::DELETE, "/?foo=bar&foo=bar", None))
); );
} }
@ -333,7 +333,7 @@ mod tests {
fn repeated_query_keys_distinct_values_does_not_match() { fn repeated_query_keys_distinct_values_does_not_match() {
assert_eq!( assert_eq!(
Match::No, Match::No,
key_value_spec().matches(&req(&Method::DELETE, "/?foo=bar&foo=baz")) key_value_spec().matches(&req(&Method::DELETE, "/?foo=bar&foo=baz", None))
); );
} }
@ -354,11 +354,11 @@ mod tests {
#[test] #[test]
fn empty_segments_in_the_middle_do_matter() { fn empty_segments_in_the_middle_do_matter() {
assert_eq!(Match::Yes, ab_spec().matches(&req(&Method::GET, "/a/b"))); assert_eq!(Match::Yes, ab_spec().matches(&req(&Method::GET, "/a/b", None)));
let misses = vec![(Method::GET, "/a//b"), (Method::GET, "//////a//b")]; let misses = vec![(Method::GET, "/a//b"), (Method::GET, "//////a//b")];
for (method, uri) in &misses { for (method, uri) in &misses {
assert_eq!(Match::No, ab_spec().matches(&req(method, uri))); assert_eq!(Match::No, ab_spec().matches(&req(method, uri, None)));
} }
} }
@ -379,10 +379,10 @@ mod tests {
(Method::GET, "/a//b"), // Label is bound to `""`. (Method::GET, "/a//b"), // Label is bound to `""`.
]; ];
for (method, uri) in &hits { for (method, uri) in &hits {
assert_eq!(Match::Yes, label_spec.matches(&req(method, uri))); assert_eq!(Match::Yes, label_spec.matches(&req(method, uri, None)));
} }
assert_eq!(Match::No, label_spec.matches(&req(&Method::GET, "/a///b"))); assert_eq!(Match::No, label_spec.matches(&req(&Method::GET, "/a///b", None)));
} }
#[test] #[test]
@ -403,7 +403,7 @@ mod tests {
(Method::GET, "/a///a//b///suffix"), (Method::GET, "/a///a//b///suffix"),
]; ];
for (method, uri) in &hits { for (method, uri) in &hits {
assert_eq!(Match::Yes, greedy_label_spec.matches(&req(method, uri))); assert_eq!(Match::Yes, greedy_label_spec.matches(&req(method, uri, None)));
} }
} }
@ -418,7 +418,7 @@ mod tests {
(Method::GET, "//a//b////"), (Method::GET, "//a//b////"),
]; ];
for (method, uri) in &misses { for (method, uri) in &misses {
assert_eq!(Match::No, ab_spec().matches(&req(method, uri))); assert_eq!(Match::No, ab_spec().matches(&req(method, uri, None)));
} }
} }
@ -432,13 +432,13 @@ mod tests {
let misses = vec![(Method::GET, "/a"), (Method::GET, "/a//"), (Method::GET, "/a///")]; let misses = vec![(Method::GET, "/a"), (Method::GET, "/a//"), (Method::GET, "/a///")];
for (method, uri) in &misses { for (method, uri) in &misses {
assert_eq!(Match::No, label_spec.matches(&req(method, uri))); assert_eq!(Match::No, label_spec.matches(&req(method, uri, None)));
} }
// In the second example, the label is bound to `""`. // In the second example, the label is bound to `""`.
let hits = vec![(Method::GET, "/a/label"), (Method::GET, "/a/")]; let hits = vec![(Method::GET, "/a/label"), (Method::GET, "/a/")];
for (method, uri) in &hits { for (method, uri) in &hits {
assert_eq!(Match::Yes, label_spec.matches(&req(method, uri))); assert_eq!(Match::Yes, label_spec.matches(&req(method, uri, None)));
} }
} }
} }

View File

@ -25,7 +25,8 @@ use crate::protocols::Protocol;
#[derive(Debug)] #[derive(Debug)]
pub enum RuntimeErrorKind { pub enum RuntimeErrorKind {
// UnknownOperation, /// The requested operation does not exist.
UnknownOperation,
/// Request failed to deserialize or response failed to serialize. /// Request failed to deserialize or response failed to serialize.
Serialization(crate::Error), Serialization(crate::Error),
/// As of writing, this variant can only occur upon failure to extract an /// As of writing, this variant can only occur upon failure to extract an
@ -43,6 +44,7 @@ impl RuntimeErrorKind {
match self { match self {
RuntimeErrorKind::Serialization(_) => "SerializationException", RuntimeErrorKind::Serialization(_) => "SerializationException",
RuntimeErrorKind::InternalFailure(_) => "InternalFailureException", RuntimeErrorKind::InternalFailure(_) => "InternalFailureException",
RuntimeErrorKind::UnknownOperation => "UnknownOperation",
} }
} }
} }
@ -58,11 +60,16 @@ impl axum_core::response::IntoResponse for RuntimeError {
let status_code = match self.kind { let status_code = match self.kind {
RuntimeErrorKind::Serialization(_) => http::StatusCode::BAD_REQUEST, RuntimeErrorKind::Serialization(_) => http::StatusCode::BAD_REQUEST,
RuntimeErrorKind::InternalFailure(_) => http::StatusCode::INTERNAL_SERVER_ERROR, RuntimeErrorKind::InternalFailure(_) => http::StatusCode::INTERNAL_SERVER_ERROR,
RuntimeErrorKind::UnknownOperation => http::StatusCode::NOT_FOUND,
}; };
let body = crate::body::to_boxed(match self.protocol { let body = crate::body::to_boxed(match self.protocol {
Protocol::RestJson1 => "{}", Protocol::RestJson1 => "{}",
Protocol::RestXml => "", Protocol::RestXml => "",
// See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#empty-body-serialization
Protocol::AwsJson10 => "",
// See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#empty-body-serialization
Protocol::AwsJson11 => "",
}); });
let mut builder = http::Response::builder(); let mut builder = http::Response::builder();
@ -74,9 +81,9 @@ impl axum_core::response::IntoResponse for RuntimeError {
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.header("X-Amzn-Errortype", self.kind.name()); .header("X-Amzn-Errortype", self.kind.name());
} }
Protocol::RestXml => { Protocol::RestXml => builder = builder.header("Content-Type", "application/xml"),
builder = builder.header("Content-Type", "application/xml"); Protocol::AwsJson10 => builder = builder.header("Content-Type", "application/x-amz-json-1.0"),
} Protocol::AwsJson11 => builder = builder.header("Content-Type", "application/x-amz-json-1.1"),
} }
builder = builder.extension(crate::extension::RuntimeErrorExtension::new(String::from( builder = builder.extension(crate::extension::RuntimeErrorExtension::new(String::from(