Initial awsJson 1.0 and 1.1 server implementation (#1279)

To support these new protocols, we also changed the runtime Router definition as for both awsJson 1.0 and 1.1, every request MUST be sent to the root URL (/) using the HTTP "POST" method.

The runtime Router now supports explicitly all the available protocols.

There are still some protocol tests failing for awsJson 1.0 and 1.1. The failure are caused by the missing implementation of @endpoint trait and date parsing. Protocol tests for these 2 protocols are heavily biased towards Responses. Before announcing support for awsJson 1.0 and 1.1, we should increase the protocol tests coverage.

Signed-off-by: Bigo <1781140+crisidev@users.noreply.github.com>
Co-authored-by: david-perez <d@vidp.dev>
This commit is contained in:
Matteo Bigoi 2022-04-19 17:41:03 +01:00 committed by GitHub
parent a870d2ad05
commit 2931c9e1e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 826 additions and 198 deletions

View File

@ -36,6 +36,8 @@ val allCodegenTests = listOf(
CodegenTest("com.amazonaws.simple#SimpleService", "simple"),
CodegenTest("aws.protocoltests.restjson#RestJson", "rest_json"),
CodegenTest("aws.protocoltests.restjson.validation#RestJsonValidation", "rest_json_validation"),
CodegenTest("aws.protocoltests.json10#JsonRpc10", "json_rpc10"),
CodegenTest("aws.protocoltests.json#JsonProtocol", "json_rpc11"),
CodegenTest("aws.protocoltests.misc#MiscService", "misc"),
CodegenTest("com.amazonaws.ebs#Ebs", "ebs"),
CodegenTest("com.amazonaws.s3#AmazonS3", "s3"),

View File

@ -38,4 +38,7 @@ object ServerRuntimeType {
fun ResponseRejection(runtimeConfig: RuntimeConfig) =
RuntimeType("ResponseRejection", ServerCargoDependency.SmithyHttpServer(runtimeConfig), "${runtimeConfig.crateSrcPrefix}_http_server::rejection")
fun Protocol(runtimeConfig: RuntimeConfig) =
RuntimeType("Protocol", ServerCargoDependency.SmithyHttpServer(runtimeConfig), "${runtimeConfig.crateSrcPrefix}_http_server::protocols")
}

View File

@ -5,6 +5,10 @@
package software.amazon.smithy.rust.codegen.server.smithy.generators
import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.rust.codegen.rustlang.Attribute
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
@ -28,7 +32,9 @@ class ServerOperationRegistryGenerator(
private val httpBindingResolver: HttpBindingResolver,
private val operations: List<OperationShape>,
) {
private val protocol = codegenContext.protocol
private val symbolProvider = codegenContext.symbolProvider
private val serviceName = codegenContext.serviceShape.toShapeId().name
private val operationNames = operations.map { symbolProvider.toSymbol(it).name.toSnakeCase() }
private val runtimeConfig = codegenContext.runtimeConfig
private val codegenScope = arrayOf(
@ -223,7 +229,7 @@ class ServerOperationRegistryGenerator(
rustTemplate(
"""
$requestSpecs
#{Router}::from_box_clone_service_iter($towerServices)
#{Router}::${runtimeRouterConstructor()}($towerServices)
""".trimIndent(),
*codegenScope
)
@ -241,12 +247,42 @@ class ServerOperationRegistryGenerator(
}
/*
* Generate the `RequestSpec`s for an operation based on its HTTP-bound route.
* Finds the runtime function to construct a new `Router` based on the Protocol.
*/
private fun OperationShape.requestSpec(): String {
private fun runtimeRouterConstructor(): String =
when (protocol) {
RestJson1Trait.ID -> "new_rest_json_router"
RestXmlTrait.ID -> "new_rest_xml_router"
AwsJson1_0Trait.ID -> "new_aws_json_10_router"
AwsJson1_1Trait.ID -> "new_aws_json_11_router"
else -> TODO("Protocol $protocol not supported yet")
}
/*
* Returns the `RequestSpec`s for an operation based on its HTTP-bound route.
*/
private fun OperationShape.requestSpec(): String =
when (protocol) {
RestJson1Trait.ID, RestXmlTrait.ID -> restRequestSpec()
AwsJson1_0Trait.ID, AwsJson1_1Trait.ID -> awsJsonOperationName()
else -> TODO("Protocol $protocol not supported yet")
}
/*
* Returns an AwsJson specific runtime `RequestSpec`.
*/
private fun OperationShape.awsJsonOperationName(): String {
val operationName = symbolProvider.toSymbol(this).name
// TODO(https://github.com/awslabs/smithy-rs/issues/950): Support the `endpoint` trait: https://awslabs.github.io/smithy/1.0/spec/core/endpoint-traits.html#endpoint-trait
return """String::from("$serviceName.$operationName")"""
}
/*
* Generates a REST (RestJson1, RestXml) specific runtime `RequestSpec`.
*/
private fun OperationShape.restRequestSpec(): String {
val httpTrait = httpBindingResolver.httpTrait(this)
val namespace = ServerRuntimeType.RequestSpecModule(runtimeConfig).fullyQualifiedName()
// TODO(https://github.com/awslabs/smithy-rs/issues/950): Support the `endpoint` trait.
val pathSegments = httpTrait.uri.segments.map {
"$namespace::PathSegment::" +
@ -268,7 +304,7 @@ class ServerOperationRegistryGenerator(
$namespace::PathSpec::from_vector_unchecked(vec![${pathSegments.joinToString()}]),
$namespace::QuerySpec::from_vector_unchecked(vec![${querySegments.joinToString()}])
)
)
),
)
""".trimIndent()
}

View File

@ -437,8 +437,8 @@ class ServerProtocolTestGenerator(
else -> {
rustWriter.rustTemplate(
"""
#{AssertEq}(parsed.$memberName, expected.$memberName, "Unexpected value for `$memberName`");
""",
#{AssertEq}(parsed.$memberName, expected.$memberName, "Unexpected value for `$memberName`");
""",
*codegenScope
)
}
@ -798,6 +798,18 @@ class ServerProtocolTestGenerator(
FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostAccelerateAddressing", TestType.Request),
FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostDualstackAccelerateAddressing", TestType.Request),
FailingTest("com.amazonaws.s3#AmazonS3", "S3OperationAddressingPreferred", TestType.Request),
// AwsJson1.0 failing tests.
FailingTest("aws.protocoltests.json10#JsonRpc10", "AwsJson10EndpointTraitWithHostLabel", TestType.Request),
FailingTest("aws.protocoltests.json10#JsonRpc10", "AwsJson10EndpointTrait", TestType.Request),
// AwsJson1.1 failing tests.
FailingTest("aws.protocoltests.json#JsonProtocol", "AwsJson11EndpointTraitWithHostLabel", TestType.Request),
FailingTest("aws.protocoltests.json#JsonProtocol", "AwsJson11EndpointTrait", TestType.Request),
FailingTest("aws.protocoltests.json#JsonProtocol", "parses_httpdate_timestamps", TestType.Response),
FailingTest("aws.protocoltests.json#JsonProtocol", "parses_iso8601_timestamps", TestType.Response),
FailingTest("aws.protocoltests.json#JsonProtocol", "parses_the_request_id_from_the_response", TestType.Response),
)
private val RunOnly: Set<String>? = null

View File

@ -0,0 +1,97 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
package software.amazon.smithy.rust.codegen.server.smithy.protocols
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.rust.codegen.rustlang.Writable
import software.amazon.smithy.rust.codegen.rustlang.escape
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.writable
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.smithy.protocols.AwsJson
import software.amazon.smithy.rust.codegen.smithy.protocols.AwsJsonVersion
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingResolver
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolGeneratorFactory
import software.amazon.smithy.rust.codegen.smithy.protocols.awsJsonFieldName
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonCustomization
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSection
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator
import software.amazon.smithy.rust.codegen.util.hasTrait
/*
* AwsJson 1.0 and 1.1 server-side protocol factory. This factory creates the [ServerHttpBoundProtocolGenerator]
* with AwsJson specific configurations.
*/
class ServerAwsJsonFactory(private val version: AwsJsonVersion) : ProtocolGeneratorFactory<ServerHttpBoundProtocolGenerator> {
override fun protocol(codegenContext: CodegenContext): Protocol = ServerAwsJson(codegenContext, version)
override fun buildProtocolGenerator(codegenContext: CodegenContext): ServerHttpBoundProtocolGenerator =
ServerHttpBoundProtocolGenerator(codegenContext, protocol(codegenContext))
override fun transformModel(model: Model): Model = model
override fun support(): ProtocolSupport {
return ProtocolSupport(
/* Client support */
requestSerialization = false,
requestBodySerialization = false,
responseDeserialization = false,
errorDeserialization = false,
/* Server support */
requestDeserialization = true,
requestBodyDeserialization = true,
responseSerialization = true,
errorSerialization = true
)
}
}
/**
* AwsJson requires errors to be serialized with an additional "__type" field. This
* customization writes the right field depending on the version of the AwsJson protocol.
*/
class ServerAwsJsonError(private val awsJsonVersion: AwsJsonVersion) : JsonCustomization() {
override fun section(section: JsonSection): Writable = when (section) {
is JsonSection.ServerError -> writable {
if (section.structureShape.hasTrait<ErrorTrait>()) {
val typeId = when (awsJsonVersion) {
// AwsJson 1.0 wants the whole shape ID (namespace#Shape).
// https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#operation-error-serialization
AwsJsonVersion.Json10 -> section.structureShape.id.toString()
// AwsJson 1.1 wants only the shape name (Shape).
// https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#operation-error-serialization
AwsJsonVersion.Json11 -> section.structureShape.id.name.toString()
}
rust("""${section.jsonObject}.key("__type").string("${escape(typeId)}");""")
}
}
}
}
/**
* AwsJson requires errors to be serialized with an additional "__type" field. This class
* customizes [JsonSerializerGenerator] to add this functionality.
*/
class ServerAwsJsonSerializerGenerator(
private val codegenContext: CodegenContext,
private val httpBindingResolver: HttpBindingResolver,
private val awsJsonVersion: AwsJsonVersion,
private val jsonSerializerGenerator: JsonSerializerGenerator =
JsonSerializerGenerator(codegenContext, httpBindingResolver, ::awsJsonFieldName, customizations = listOf(ServerAwsJsonError(awsJsonVersion)))
) : StructuredDataSerializerGenerator by jsonSerializerGenerator
class ServerAwsJson(
private val codegenContext: CodegenContext,
private val awsJsonVersion: AwsJsonVersion
) : AwsJson(codegenContext, awsJsonVersion) {
override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator =
ServerAwsJsonSerializerGenerator(codegenContext, httpBindingResolver, awsJsonVersion)
}

View File

@ -5,6 +5,8 @@
package software.amazon.smithy.rust.codegen.server.smithy.protocols
import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.codegen.core.Symbol
@ -1064,10 +1066,16 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
private fun getContentTypeCheck(): String {
when (codegenContext.protocol) {
RestJson1Trait.ID -> {
return "check_json_content_type"
return "check_rest_json_1_content_type"
}
RestXmlTrait.ID -> {
return "check_xml_content_type"
return "check_rest_xml_content_type"
}
AwsJson1_0Trait.ID -> {
return "check_aws_json_10_content_type"
}
AwsJson1_1Trait.ID -> {
return "check_aws_json_11_content_type"
}
else -> {
TODO("Protocol ${codegenContext.protocol} not supported yet")
@ -1086,7 +1094,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
return ServerRuntimeType.RequestRejection(runtimeConfig)
}
when (codegenContext.protocol) {
RestJson1Trait.ID -> {
RestJson1Trait.ID, AwsJson1_0Trait.ID, AwsJson1_1Trait.ID -> {
return CargoDependency.smithyJson(runtimeConfig).asType().member("deserialize").member("Error")
}
RestXmlTrait.ID -> {

View File

@ -5,6 +5,8 @@
package software.amazon.smithy.rust.codegen.server.smithy.protocols
import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.codegen.core.CodegenException
@ -14,6 +16,7 @@ import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.traits.Trait
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.AwsJsonVersion
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolGeneratorFactory
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolMap
@ -39,6 +42,8 @@ class ServerProtocolLoader(private val supportedProtocols: ProtocolMap) {
val DefaultProtocols = mapOf(
RestJson1Trait.ID to ServerRestJsonFactory(),
RestXmlTrait.ID to ServerRestXmlFactory(),
AwsJson1_0Trait.ID to ServerAwsJsonFactory(AwsJsonVersion.Json10),
AwsJson1_1Trait.ID to ServerAwsJsonFactory(AwsJsonVersion.Json11),
)
}
}

View File

@ -16,9 +16,8 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.RestXml
* RestXml server-side protocol factory. This factory creates the [ServerHttpProtocolGenerator]
* with RestXml specific configurations.
*/
class ServerRestXmlFactory(private val generator: (CodegenContext) -> Protocol = { RestXml(it) }) :
ProtocolGeneratorFactory<ServerHttpBoundProtocolGenerator> {
override fun protocol(codegenContext: CodegenContext): Protocol = generator(codegenContext)
class ServerRestXmlFactory : ProtocolGeneratorFactory<ServerHttpBoundProtocolGenerator> {
override fun protocol(codegenContext: CodegenContext): Protocol = RestXml(codegenContext)
override fun buildProtocolGenerator(codegenContext: CodegenContext): ServerHttpBoundProtocolGenerator =
ServerHttpBoundProtocolGenerator(codegenContext, RestXml(codegenContext))

View File

@ -128,7 +128,7 @@ class AwsJsonSerializerGenerator(
}
}
class AwsJson(
open class AwsJson(
private val codegenContext: CodegenContext,
awsJsonVersion: AwsJsonVersion
) : Protocol {
@ -183,6 +183,4 @@ class AwsJson(
}
}
private fun awsJsonFieldName(member: MemberShape): String {
return member.memberName
}
fun awsJsonFieldName(member: MemberShape): String = member.memberName

View File

@ -34,6 +34,8 @@ import software.amazon.smithy.rust.codegen.rustlang.withBlock
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.customize.NamedSectionGenerator
import software.amazon.smithy.rust.codegen.smithy.customize.Section
import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.renderUnknownVariant
import software.amazon.smithy.rust.codegen.smithy.generators.serializationError
@ -49,11 +51,25 @@ import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.outputShape
/**
* Class describing a JSON section that can be used in a customization.
*/
sealed class JsonSection(name: String) : Section(name) {
/** Mutate the server error object prior to finalization. Eg: this can be used to inject `__type` to record the error type. */
data class ServerError(val structureShape: StructureShape, val jsonObject: String) : JsonSection("ServerError")
}
/**
* JSON customization.
*/
typealias JsonCustomization = NamedSectionGenerator<JsonSection>
class JsonSerializerGenerator(
codegenContext: CodegenContext,
private val httpBindingResolver: HttpBindingResolver,
/** Function that maps a MemberShape into a JSON field name */
private val jsonName: (MemberShape) -> String,
private val customizations: List<JsonCustomization> = listOf(),
) : StructuredDataSerializerGenerator {
private data class Context<T : Shape>(
/** Expression that retrieves a JsonValueWriter from either a JsonObjectWriter or JsonArrayWriter */
@ -155,7 +171,7 @@ class JsonSerializerGenerator(
private fun serverStructureSerializer(
fnName: String,
structureShape: StructureShape,
includedMembers: List<MemberShape>
includedMembers: List<MemberShape>,
): RuntimeType {
return RuntimeType.forInlineFun(fnName, operationSerModule) {
it.rustBlockTemplate(
@ -166,6 +182,7 @@ class JsonSerializerGenerator(
rust("let mut out = String::new();")
rustTemplate("let mut object = #{JsonObjectWriter}::new(&mut out);", *codegenScope)
serializeStructure(StructContext("object", "value", structureShape), includedMembers)
customizations.forEach { it.section(JsonSection.ServerError(structureShape, "object"))(this) }
rust("object.finish();")
rustTemplate("Ok(out)", *codegenScope)
}

View File

@ -29,6 +29,7 @@ http-body = "0.4"
hyper = { version = "0.14", features = ["server", "http1", "http2", "tcp", "stream"] }
mime = "0.3"
nom = "7"
paste = "1"
pin-project-lite = "0.2"
regex = "1.0"
serde_urlencoded = "0.7"

View File

@ -6,51 +6,206 @@
//! Protocol helpers.
use crate::rejection::RequestRejection;
use axum_core::extract::RequestParts;
use paste::paste;
#[derive(Debug)]
/// Supported protocols.
#[derive(Debug, Clone, Copy)]
pub enum Protocol {
RestJson1,
RestXml,
AwsJson10,
AwsJson11,
}
/// Validate that the request had the standard JSON content-type header.
pub fn check_json_content_type<B>(req: &RequestParts<B>) -> Result<(), RequestRejection> {
let mime = req
.headers()
.ok_or(RequestRejection::MissingJsonContentType)?
.get(http::header::CONTENT_TYPE)
.ok_or(RequestRejection::MissingJsonContentType)?
.to_str()
.map_err(|_| RequestRejection::MissingJsonContentType)?
.parse::<mime::Mime>()
.map_err(|_| RequestRejection::MimeParse)?;
/// Implement the content-type header validation for a request.
macro_rules! impl_content_type_validation {
($name:literal, $type: literal, $subtype:literal, $rejection:path) => {
paste! {
#[doc = concat!("Validates that the request has the standard `", $type, "/", $subtype, "` content-type header.")]
pub fn [<check_ $name _content_type>]<B>(req: &RequestParts<B>) -> Result<(), RequestRejection> {
let mime = req
.headers()
.ok_or($rejection)?
.get(http::header::CONTENT_TYPE)
.ok_or($rejection)?
.to_str()
.map_err(|_| $rejection)?
.parse::<mime::Mime>()
.map_err(|_| RequestRejection::MimeParse)?;
if mime.type_() == $type && mime.subtype() == $subtype {
Ok(())
} else {
Err($rejection)
}
}
}
};
}
if mime.type_() == "application"
&& (mime.subtype() == "json" || mime.suffix().filter(|name| *name == "json").is_some())
{
Ok(())
} else {
Err(RequestRejection::MissingJsonContentType)
}
}
/// Validate that the request had the standard XML content-type header.
pub fn check_xml_content_type<B>(req: &RequestParts<B>) -> Result<(), RequestRejection> {
let mime = req
.headers()
.ok_or(RequestRejection::MissingXmlContentType)?
.get(http::header::CONTENT_TYPE)
.ok_or(RequestRejection::MissingXmlContentType)?
.to_str()
.map_err(|_| RequestRejection::MissingXmlContentType)?
.parse::<mime::Mime>()
.map_err(|_| RequestRejection::MimeParse)?;
if mime.type_() == "application"
&& (mime.subtype() == "xml" || mime.suffix().filter(|name| *name == "xml").is_some())
{
Ok(())
} else {
Err(RequestRejection::MissingXmlContentType)
impl_content_type_validation!(
"rest_json_1",
"application",
"json",
RequestRejection::MissingRestJson1ContentType
);
impl_content_type_validation!(
"rest_xml",
"application",
"xml",
RequestRejection::MissingRestXmlContentType
);
impl_content_type_validation!(
"aws_json_10",
"application",
"x-amz-json-1.0",
RequestRejection::MissingAwsJson10ContentType
);
impl_content_type_validation!(
"aws_json_11",
"application",
"x-amz-json-1.1",
RequestRejection::MissingAwsJson11ContentType
);
#[cfg(test)]
mod tests {
use super::*;
use http::Request;
fn req(content_type: &str) -> RequestParts<&str> {
let request = Request::builder()
.header("content-type", content_type)
.body("")
.unwrap();
RequestParts::new(request)
}
/// This macro validates the rejection type since we cannot implement `PartialEq`
/// for `RequestRejection` as it is based on the crate error type, which uses
/// `axum_core::BoxError`.
macro_rules! validate_rejection_type {
($result:expr, $rejection:path) => {
match $result {
Ok(()) => panic!("Content-type validation is expected to fail"),
Err(e) => match e {
$rejection => {}
_ => panic!("Error {} should be {}", e.to_string(), stringify!($rejection)),
},
}
};
}
#[test]
fn validate_rest_json_1_content_type() {
// Check valid content-type header.
let request = req("application/json");
assert!(check_rest_json_1_content_type(&request).is_ok());
// Check invalid content-type header.
let invalid = vec![
req("application/ajson"),
req("application/json1"),
req("applicatio/json"),
req("application/xml"),
req("text/xml"),
req("application/x-amz-json-1.0"),
req("application/x-amz-json-1.1"),
RequestParts::new(Request::builder().body("").unwrap()),
];
for request in &invalid {
validate_rejection_type!(
check_rest_json_1_content_type(request),
RequestRejection::MissingRestJson1ContentType
);
}
// Check request with not parsable content-type header.
validate_rejection_type!(check_rest_json_1_content_type(&req("123")), RequestRejection::MimeParse);
}
#[test]
fn validate_rest_xml_content_type() {
// Check valid content-type header.
let request = req("application/xml");
assert!(check_rest_xml_content_type(&request).is_ok());
// Check invalid content-type header.
let invalid = vec![
req("application/axml"),
req("application/xml1"),
req("applicatio/xml"),
req("text/xml"),
req("application/x-amz-json-1.0"),
req("application/x-amz-json-1.1"),
RequestParts::new(Request::builder().body("").unwrap()),
];
for request in &invalid {
validate_rejection_type!(
check_rest_xml_content_type(request),
RequestRejection::MissingRestXmlContentType
);
}
// Check request with not parsable content-type header.
validate_rejection_type!(check_rest_xml_content_type(&req("123")), RequestRejection::MimeParse);
}
#[test]
fn validate_aws_json_10_content_type() {
// Check valid content-type header.
let request = req("application/x-amz-json-1.0");
assert!(check_aws_json_10_content_type(&request).is_ok());
// Check invalid content-type header.
let invalid = vec![
req("application/x-amz-json-1."),
req("application/-amz-json-1.0"),
req("application/xml"),
req("application/json"),
req("applicatio/x-amz-json-1.0"),
req("text/xml"),
req("application/x-amz-json-1.1"),
RequestParts::new(Request::builder().body("").unwrap()),
];
for request in &invalid {
validate_rejection_type!(
check_aws_json_10_content_type(request),
RequestRejection::MissingAwsJson10ContentType
);
}
// Check request with not parsable content-type header.
validate_rejection_type!(check_aws_json_10_content_type(&req("123")), RequestRejection::MimeParse);
}
#[test]
fn validate_aws_json_11_content_type() {
// Check valid content-type header.
let request = req("application/x-amz-json-1.1");
assert!(check_aws_json_11_content_type(&request).is_ok());
// Check invalid content-type header.
let invalid = vec![
req("application/x-amz-json-1."),
req("application/-amz-json-1.1"),
req("application/xml"),
req("application/json"),
req("applicatio/x-amz-json-1.1"),
req("text/xml"),
req("application/x-amz-json-1.0"),
RequestParts::new(Request::builder().body("").unwrap()),
];
for request in &invalid {
validate_rejection_type!(
check_aws_json_11_content_type(request),
RequestRejection::MissingAwsJson11ContentType
);
}
// Check request with not parsable content-type header.
validate_rejection_type!(check_aws_json_11_content_type(&req("123")), RequestRejection::MimeParse);
}
}

View File

@ -134,8 +134,10 @@ pub enum RequestRejection {
HttpBody(crate::Error),
// These are used when checking the `Content-Type` header.
MissingJsonContentType,
MissingXmlContentType,
MissingRestJson1ContentType,
MissingAwsJson10ContentType,
MissingAwsJson11ContentType,
MissingRestXmlContentType,
MimeParse,
/// Used when failing to deserialize the HTTP body's bytes into a JSON document conforming to

View File

@ -7,10 +7,15 @@
//!
//! [Smithy specification]: https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html
use self::{future::RouterFuture, request_spec::RequestSpec};
use self::future::RouterFuture;
use self::request_spec::RequestSpec;
use crate::body::{boxed, Body, BoxBody, HttpBody};
use crate::protocols::Protocol;
use crate::runtime_error::{RuntimeError, RuntimeErrorKind};
use crate::BoxError;
use axum_core::response::IntoResponse;
use http::{Request, Response, StatusCode};
use std::collections::HashMap;
use std::{
convert::Infallible,
task::{Context, Poll},
@ -31,34 +36,60 @@ mod route;
pub use self::{into_make_service::IntoMakeService, route::Route};
/// The router is a [`tower::Service`] that routes incoming requests to other `Service`s
/// based on the request's URI and HTTP method, adhering to the [Smithy specification].
/// based on the request's URI and HTTP method or on some specific header setting the target operation.
/// The former is adhering to the [Smithy specification], while the latter is adhering to
/// the [AwsJson specification].
///
/// The router is also [Protocol] aware and currently supports REST based protocols like [restJson1] or [restXml]
/// and RPC based protocols like [awsJson1.0] or [awsJson1.1].
/// It currently does not support Smithy's [endpoint trait].
///
/// You should not **instantiate** this router directly; it will be created for you from the
/// code generated from your Smithy model by `smithy-rs`.
///
/// [Smithy specification]: https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html
/// [AwsJson specification]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#protocol-behaviors
/// [Protocol]: https://awslabs.github.io/smithy/1.0/spec/aws/index.html#aws-protocols
/// [restJson1]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-restjson1-protocol.html
/// [restXml]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-restxml-protocol.html
/// [awsJson1.0]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html
/// [awsJson1.1]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html
/// [endpoint trait]: https://awslabs.github.io/smithy/1.0/spec/core/endpoint-traits.html#endpoint-trait
#[derive(Debug)]
pub struct Router<B = Body> {
routes: Vec<(Route<B>, RequestSpec)>,
routes: Routes<B>,
}
/// Protocol-aware routes types.
///
/// RestJson1 and RestXml routes are stored in a `Vec` because there can be multiple matches on the
/// request URI and we thus need to iterate the whole list and use a ranking mechanism to choose.
///
/// AwsJson 1.0 and 1.1 routes can be stored in a `HashMap` since the requested operation can be
/// directly found in the `X-Amz-Target` HTTP header.
#[derive(Debug)]
enum Routes<B = Body> {
RestXml(Vec<(Route<B>, RequestSpec)>),
RestJson1(Vec<(Route<B>, RequestSpec)>),
AwsJson10(HashMap<String, Route<B>>),
AwsJson11(HashMap<String, Route<B>>),
}
impl<B> Clone for Router<B> {
fn clone(&self) -> Self {
Self {
routes: self.routes.clone(),
}
}
}
impl<B> Default for Router<B>
where
B: Send + 'static,
{
fn default() -> Self {
Self {
routes: Default::default(),
match &self.routes {
Routes::RestJson1(routes) => Router {
routes: Routes::RestJson1(routes.clone()),
},
Routes::RestXml(routes) => Router {
routes: Routes::RestXml(routes.clone()),
},
Routes::AwsJson10(routes) => Router {
routes: Routes::AwsJson10(routes.clone()),
},
Routes::AwsJson11(routes) => Router {
routes: Routes::AwsJson11(routes.clone()),
},
}
}
}
@ -67,32 +98,29 @@ impl<B> Router<B>
where
B: Send + 'static,
{
/// Create a new `Router` from a vector of pairs of request specs and services.
///
/// If the vector is empty the router will respond `404 Not Found` to all requests.
#[doc(hidden)]
pub fn from_box_clone_service_iter<T>(routes: T) -> Self
where
T: IntoIterator<
Item = (
tower::util::BoxCloneService<Request<B>, Response<BoxBody>, Infallible>,
RequestSpec,
),
>,
{
let mut routes: Vec<(Route<B>, RequestSpec)> = routes
.into_iter()
.map(|(svc, request_spec)| (Route::from_box_clone_service(svc), request_spec))
.collect();
// Sort them once by specifity, with the more specific routes sorted before the less
// specific ones, so that when routing a request we can simply iterate through the routes
// and pick the first one that matches.
routes.sort_by_key(|(_route, request_spec)| std::cmp::Reverse(request_spec.rank()));
Self { routes }
/// Return the correct, protocol-specific "Not Found" response for an unknown operation.
fn unknown_operation(&self) -> RouterFuture<B> {
let protocol = match &self.routes {
Routes::RestJson1(_) => Protocol::RestJson1,
Routes::RestXml(_) => Protocol::RestXml,
Routes::AwsJson10(_) => Protocol::AwsJson10,
Routes::AwsJson11(_) => Protocol::AwsJson11,
};
let error = RuntimeError {
protocol,
kind: RuntimeErrorKind::UnknownOperation,
};
RouterFuture::from_response(error.into_response())
}
/// Return the HTTP error response for non allowed method.
fn method_not_allowed(&self) -> RouterFuture<B> {
RouterFuture::from_response({
let mut res = Response::new(crate::body::empty());
*res.status_mut() = StatusCode::METHOD_NOT_ALLOWED;
res
})
}
/// Convert this router into a [`MakeService`], that is a [`Service`] whose
/// response is another service.
///
@ -124,12 +152,146 @@ where
.layer_fn(Route::new)
.layer(MapResponseBodyLayer::new(boxed))
.layer(layer);
let routes = self
.routes
match self.routes {
Routes::RestJson1(routes) => {
let routes = routes
.into_iter()
.map(|(route, request_spec)| (Layer::layer(&layer, route), request_spec))
.collect();
Router {
routes: Routes::RestJson1(routes),
}
}
Routes::RestXml(routes) => {
let routes = routes
.into_iter()
.map(|(route, request_spec)| (Layer::layer(&layer, route), request_spec))
.collect();
Router {
routes: Routes::RestXml(routes),
}
}
Routes::AwsJson10(routes) => {
let routes = routes
.into_iter()
.map(|(operation, route)| (operation, Layer::layer(&layer, route)))
.collect();
Router {
routes: Routes::AwsJson10(routes),
}
}
Routes::AwsJson11(routes) => {
let routes = routes
.into_iter()
.map(|(operation, route)| (operation, Layer::layer(&layer, route)))
.collect();
Router {
routes: Routes::AwsJson11(routes),
}
}
}
}
/// Create a new RestJson1 `Router` from an iterator over pairs of [`RequestSpec`]s and services.
///
/// If the iterator is empty the router will respond `404 Not Found` to all requests.
#[doc(hidden)]
pub fn new_rest_json_router<T>(routes: T) -> Self
where
T: IntoIterator<
Item = (
tower::util::BoxCloneService<Request<B>, Response<BoxBody>, Infallible>,
RequestSpec,
),
>,
{
let mut routes: Vec<(Route<B>, RequestSpec)> = routes
.into_iter()
.map(|(route, request_spec)| (Layer::layer(&layer, route), request_spec))
.map(|(svc, request_spec)| (Route::from_box_clone_service(svc), request_spec))
.collect();
Router { routes }
// Sort them once by specifity, with the more specific routes sorted before the less
// specific ones, so that when routing a request we can simply iterate through the routes
// and pick the first one that matches.
routes.sort_by_key(|(_route, request_spec)| std::cmp::Reverse(request_spec.rank()));
Self {
routes: Routes::RestJson1(routes),
}
}
/// Create a new RestXml `Router` from an iterator over pairs of [`RequestSpec`]s and services.
///
/// If the iterator is empty the router will respond `404 Not Found` to all requests.
#[doc(hidden)]
pub fn new_rest_xml_router<T>(routes: T) -> Self
where
T: IntoIterator<
Item = (
tower::util::BoxCloneService<Request<B>, Response<BoxBody>, Infallible>,
RequestSpec,
),
>,
{
let mut routes: Vec<(Route<B>, RequestSpec)> = routes
.into_iter()
.map(|(svc, request_spec)| (Route::from_box_clone_service(svc), request_spec))
.collect();
// Sort them once by specifity, with the more specific routes sorted before the less
// specific ones, so that when routing a request we can simply iterate through the routes
// and pick the first one that matches.
routes.sort_by_key(|(_route, request_spec)| std::cmp::Reverse(request_spec.rank()));
Self {
routes: Routes::RestXml(routes),
}
}
/// Create a new AwsJson 1.0 `Router` from an iterator over pairs of operation names and services.
///
/// If the iterator is empty the router will respond `404 Not Found` to all requests.
#[doc(hidden)]
pub fn new_aws_json_10_router<T>(routes: T) -> Self
where
T: IntoIterator<
Item = (
tower::util::BoxCloneService<Request<B>, Response<BoxBody>, Infallible>,
String,
),
>,
{
let routes = routes
.into_iter()
.map(|(svc, operation)| (operation, Route::from_box_clone_service(svc)))
.collect();
Self {
routes: Routes::AwsJson10(routes),
}
}
/// Create a new AwsJson 1.1 `Router` from a vector of pairs of operations and services.
///
/// If the vector is empty the router will respond `404 Not Found` to all requests.
#[doc(hidden)]
pub fn new_aws_json_11_router<T>(routes: T) -> Self
where
T: IntoIterator<
Item = (
tower::util::BoxCloneService<Request<B>, Response<BoxBody>, Infallible>,
String,
),
>,
{
let routes = routes
.into_iter()
.map(|(svc, operation)| (operation, Route::from_box_clone_service(svc)))
.collect();
Self {
routes: Routes::AwsJson11(routes),
}
}
}
@ -148,44 +310,84 @@ where
#[inline]
fn call(&mut self, req: Request<B>) -> Self::Future {
let mut method_not_allowed = false;
match &self.routes {
// REST routes.
Routes::RestJson1(routes) | Routes::RestXml(routes) => {
let mut method_not_allowed = false;
for (route, request_spec) in &self.routes {
match request_spec.matches(&req) {
request_spec::Match::Yes => {
return RouterFuture::from_oneshot(route.clone().oneshot(req));
// Loop through all the routes and validate if any of them matches. Routes are already ranked.
for (route, request_spec) in routes {
match request_spec.matches(&req) {
request_spec::Match::Yes => {
return RouterFuture::from_oneshot(route.clone().oneshot(req));
}
request_spec::Match::MethodNotAllowed => method_not_allowed = true,
// Continue looping to see if another route matches.
request_spec::Match::No => continue,
}
}
request_spec::Match::MethodNotAllowed => method_not_allowed = true,
// Continue looping to see if another route matches.
request_spec::Match::No => continue,
if method_not_allowed {
// The HTTP method is not correct.
self.method_not_allowed()
} else {
// In any other case return the `RuntimeError::UnknownOperation`.
self.unknown_operation()
}
}
// AwsJson routes.
Routes::AwsJson10(routes) | Routes::AwsJson11(routes) => {
if req.uri() == "/" {
// Check the request method for POST.
if req.method() == http::Method::POST {
// Find the `x-amz-target` header.
if let Some(target) = req.headers().get("x-amz-target") {
if let Ok(target) = target.to_str() {
// Lookup in the `HashMap` for a route for the target.
let route = routes.get(target);
if let Some(route) = route {
return RouterFuture::from_oneshot(route.clone().oneshot(req));
}
}
}
} else {
// The HTTP method is not POST.
return self.method_not_allowed();
}
}
// In any other case return the `RuntimeError::UnknownOperation`.
self.unknown_operation()
}
}
let status_code = if method_not_allowed {
StatusCode::METHOD_NOT_ALLOWED
} else {
StatusCode::NOT_FOUND
};
RouterFuture::from_response(
Response::builder()
.status(status_code)
.body(crate::body::empty())
.unwrap(),
)
}
}
#[cfg(test)]
mod tests {
mod rest_tests {
use super::*;
use crate::{body::boxed, routing::request_spec::*};
use futures_util::Future;
use http::Method;
use http::{HeaderMap, Method};
use std::pin::Pin;
/// Helper function to build a `Request`. Used in other test modules.
pub fn req(method: &Method, uri: &str) -> Request<()> {
Request::builder().method(method).uri(uri).body(()).unwrap()
pub fn req(method: &Method, uri: &str, headers: Option<HeaderMap>) -> Request<()> {
let mut r = Request::builder().method(method).uri(uri).body(()).unwrap();
if let Some(headers) = headers {
*r.headers_mut() = headers
}
r
}
// Returns a `Response`'s body as a `String`, without consuming the response.
pub async fn get_body_as_string<B>(res: &mut Response<B>) -> String
where
B: http_body::Body + std::marker::Unpin,
B::Error: std::fmt::Debug,
{
let body_mut = res.body_mut();
let body_bytes = hyper::body::to_bytes(body_mut).await.unwrap();
String::from(std::str::from_utf8(&body_bytes).unwrap())
}
/// A service that returns its name and the request's URI in the response body.
@ -210,17 +412,6 @@ mod tests {
}
}
// Returns a `Response`'s body as a `String`, without consuming the response.
async fn get_body_as_str<B>(res: &mut Response<B>) -> String
where
B: http_body::Body + std::marker::Unpin,
B::Error: std::fmt::Debug,
{
let body_mut = res.body_mut();
let body_bytes = hyper::body::to_bytes(body_mut).await.unwrap();
String::from(std::str::from_utf8(&body_bytes).unwrap())
}
// This test is a rewrite of `mux.spec.ts`.
// https://github.com/awslabs/smithy-typescript/blob/fbf97a9bf4c1d8cf7f285ea7c24e1f0ef280142a/smithy-typescript-ssdk-libs/server-common/src/httpbinding/mux.spec.ts
#[tokio::test]
@ -271,55 +462,64 @@ mod tests {
),
];
let mut router = Router::from_box_clone_service_iter(request_specs.into_iter().map(|(spec, svc_name)| {
// Test both RestJson1 and RestXml routers.
let router_json = Router::new_rest_json_router(request_specs.clone().into_iter().map(|(spec, svc_name)| {
(
tower::util::BoxCloneService::new(NamedEchoUriService(String::from(svc_name))),
spec,
)
}));
let router_xml = Router::new_rest_xml_router(request_specs.into_iter().map(|(spec, svc_name)| {
(
tower::util::BoxCloneService::new(NamedEchoUriService(String::from(svc_name))),
spec,
)
}));
let hits = vec![
("A", Method::GET, "/a/b/c"),
("MiddleGreedy", Method::GET, "/mg/a/z"),
("MiddleGreedy", Method::GET, "/mg/a/b/c/d/z?abc=def"),
("Delete", Method::DELETE, "/?foo=bar&baz=quux"),
("Delete", Method::DELETE, "/?foo=bar&baz"),
("Delete", Method::DELETE, "/?foo=bar&baz=&"),
("Delete", Method::DELETE, "/?foo=bar&baz=quux&baz=grault"),
("QueryKeyOnly", Method::POST, "/query_key_only?foo=bar"),
("QueryKeyOnly", Method::POST, "/query_key_only?foo"),
("QueryKeyOnly", Method::POST, "/query_key_only?foo="),
("QueryKeyOnly", Method::POST, "/query_key_only?foo=&"),
];
for (svc_name, method, uri) in &hits {
let mut res = router.call(req(method, uri)).await.unwrap();
let actual_body = get_body_as_str(&mut res).await;
for mut router in [router_json, router_xml] {
let hits = vec![
("A", Method::GET, "/a/b/c"),
("MiddleGreedy", Method::GET, "/mg/a/z"),
("MiddleGreedy", Method::GET, "/mg/a/b/c/d/z?abc=def"),
("Delete", Method::DELETE, "/?foo=bar&baz=quux"),
("Delete", Method::DELETE, "/?foo=bar&baz"),
("Delete", Method::DELETE, "/?foo=bar&baz=&"),
("Delete", Method::DELETE, "/?foo=bar&baz=quux&baz=grault"),
("QueryKeyOnly", Method::POST, "/query_key_only?foo=bar"),
("QueryKeyOnly", Method::POST, "/query_key_only?foo"),
("QueryKeyOnly", Method::POST, "/query_key_only?foo="),
("QueryKeyOnly", Method::POST, "/query_key_only?foo=&"),
];
for (svc_name, method, uri) in &hits {
let mut res = router.call(req(method, uri, None)).await.unwrap();
let actual_body = get_body_as_string(&mut res).await;
assert_eq!(format!("{} :: {}", svc_name, uri), actual_body);
}
assert_eq!(format!("{} :: {}", svc_name, uri), actual_body);
}
for (_, _, uri) in hits {
let res = router.call(req(&Method::PATCH, uri)).await.unwrap();
assert_eq!(StatusCode::METHOD_NOT_ALLOWED, res.status());
}
for (_, _, uri) in hits {
let res = router.call(req(&Method::PATCH, uri, None)).await.unwrap();
assert_eq!(StatusCode::METHOD_NOT_ALLOWED, res.status());
}
let misses = vec![
(Method::GET, "/a"),
(Method::GET, "/a/b"),
(Method::GET, "/mg"),
(Method::GET, "/mg/q"),
(Method::GET, "/mg/z"),
(Method::GET, "/mg/a/b/z/c"),
(Method::DELETE, "/?foo=bar"),
(Method::DELETE, "/?foo=bar"),
(Method::DELETE, "/?baz=quux"),
(Method::POST, "/query_key_only?baz=quux"),
(Method::GET, "/"),
(Method::POST, "/"),
];
for (method, miss) in misses {
let res = router.call(req(&method, miss)).await.unwrap();
assert_eq!(StatusCode::NOT_FOUND, res.status());
let misses = vec![
(Method::GET, "/a"),
(Method::GET, "/a/b"),
(Method::GET, "/mg"),
(Method::GET, "/mg/q"),
(Method::GET, "/mg/z"),
(Method::GET, "/mg/a/b/z/c"),
(Method::DELETE, "/?foo=bar"),
(Method::DELETE, "/?foo=bar"),
(Method::DELETE, "/?baz=quux"),
(Method::POST, "/query_key_only?baz=quux"),
(Method::GET, "/"),
(Method::POST, "/"),
];
for (method, miss) in misses {
let res = router.call(req(&method, miss, None)).await.unwrap();
assert_eq!(StatusCode::NOT_FOUND, res.status());
}
}
}
@ -364,7 +564,7 @@ mod tests {
),
];
let mut router = Router::from_box_clone_service_iter(request_specs.into_iter().map(|(spec, svc_name)| {
let mut router = Router::new_rest_json_router(request_specs.into_iter().map(|(spec, svc_name)| {
(
tower::util::BoxCloneService::new(NamedEchoUriService(String::from(svc_name))),
spec,
@ -378,10 +578,96 @@ mod tests {
("B2", Method::GET, "/b/foo?q=baz"),
];
for (svc_name, method, uri) in &hits {
let mut res = router.call(req(method, uri)).await.unwrap();
let actual_body = get_body_as_str(&mut res).await;
let mut res = router.call(req(method, uri, None)).await.unwrap();
let actual_body = get_body_as_string(&mut res).await;
assert_eq!(format!("{} :: {}", svc_name, uri), actual_body);
}
}
}
#[cfg(test)]
mod awsjson_tests {
use super::rest_tests::{get_body_as_string, req};
use super::*;
use crate::body::boxed;
use futures_util::Future;
use http::{HeaderMap, HeaderValue, Method};
use pretty_assertions::assert_eq;
use std::pin::Pin;
/// A service that returns its name and the request's URI in the response body.
#[derive(Clone)]
struct NamedEchoOperationService(String);
impl<B> Service<Request<B>> for NamedEchoOperationService {
type Response = Response<BoxBody>;
type Error = Infallible;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
#[inline]
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
#[inline]
fn call(&mut self, req: Request<B>) -> Self::Future {
let target = req
.headers()
.get("x-amz-target")
.map(|x| x.to_str().unwrap())
.unwrap_or("unknown");
let body = boxed(Body::from(format!("{} :: {}", self.0, target)));
let fut = async { Ok(Response::builder().status(&http::StatusCode::OK).body(body).unwrap()) };
Box::pin(fut)
}
}
#[tokio::test]
async fn simple_routing() {
let routes = vec![("Service.Operation", "A")];
let router_json10 = Router::new_aws_json_10_router(routes.clone().into_iter().map(|(operation, svc_name)| {
(
tower::util::BoxCloneService::new(NamedEchoOperationService(String::from(svc_name))),
operation.to_string(),
)
}));
let router_json11 = Router::new_aws_json_11_router(routes.into_iter().map(|(operation, svc_name)| {
(
tower::util::BoxCloneService::new(NamedEchoOperationService(String::from(svc_name))),
operation.to_string(),
)
}));
for mut router in [router_json10, router_json11] {
let mut headers = HeaderMap::new();
headers.insert("x-amz-target", HeaderValue::from_static("Service.Operation"));
// Valid request, should return a valid body.
let mut res = router
.call(req(&Method::POST, "/", Some(headers.clone())))
.await
.unwrap();
let actual_body = get_body_as_string(&mut res).await;
assert_eq!(format!("{} :: {}", "A", "Service.Operation"), actual_body);
// No headers, should return NOT_FOUND.
let res = router.call(req(&Method::POST, "/", None)).await.unwrap();
assert_eq!(res.status(), StatusCode::NOT_FOUND);
// Wrong HTTP method, should return METHOD_NOT_ALLOWED.
let res = router
.call(req(&Method::GET, "/", Some(headers.clone())))
.await
.unwrap();
assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED);
// Wrong URI, should return NOT_FOUND.
let res = router
.call(req(&Method::POST, "/something", Some(headers)))
.await
.unwrap();
assert_eq!(res.status(), StatusCode::NOT_FOUND);
}
}
}

View File

@ -243,7 +243,7 @@ impl RequestSpec {
#[cfg(test)]
mod tests {
use super::super::tests::req;
use super::super::rest_tests::req;
use super::*;
use http::Method;
@ -295,7 +295,7 @@ mod tests {
(Method::GET, "/mg/a/z/z/z"),
];
for (method, uri) in &hits {
assert_eq!(Match::Yes, spec.matches(&req(method, uri)));
assert_eq!(Match::Yes, spec.matches(&req(method, uri, None)));
}
}
@ -309,7 +309,7 @@ mod tests {
(Method::DELETE, "/?foo&foo"),
];
for (method, uri) in &hits {
assert_eq!(Match::Yes, spec.matches(&req(method, uri)));
assert_eq!(Match::Yes, spec.matches(&req(method, uri, None)));
}
}
@ -325,7 +325,7 @@ mod tests {
fn repeated_query_keys_same_values_match() {
assert_eq!(
Match::Yes,
key_value_spec().matches(&req(&Method::DELETE, "/?foo=bar&foo=bar"))
key_value_spec().matches(&req(&Method::DELETE, "/?foo=bar&foo=bar", None))
);
}
@ -333,7 +333,7 @@ mod tests {
fn repeated_query_keys_distinct_values_does_not_match() {
assert_eq!(
Match::No,
key_value_spec().matches(&req(&Method::DELETE, "/?foo=bar&foo=baz"))
key_value_spec().matches(&req(&Method::DELETE, "/?foo=bar&foo=baz", None))
);
}
@ -354,11 +354,11 @@ mod tests {
#[test]
fn empty_segments_in_the_middle_do_matter() {
assert_eq!(Match::Yes, ab_spec().matches(&req(&Method::GET, "/a/b")));
assert_eq!(Match::Yes, ab_spec().matches(&req(&Method::GET, "/a/b", None)));
let misses = vec![(Method::GET, "/a//b"), (Method::GET, "//////a//b")];
for (method, uri) in &misses {
assert_eq!(Match::No, ab_spec().matches(&req(method, uri)));
assert_eq!(Match::No, ab_spec().matches(&req(method, uri, None)));
}
}
@ -379,10 +379,10 @@ mod tests {
(Method::GET, "/a//b"), // Label is bound to `""`.
];
for (method, uri) in &hits {
assert_eq!(Match::Yes, label_spec.matches(&req(method, uri)));
assert_eq!(Match::Yes, label_spec.matches(&req(method, uri, None)));
}
assert_eq!(Match::No, label_spec.matches(&req(&Method::GET, "/a///b")));
assert_eq!(Match::No, label_spec.matches(&req(&Method::GET, "/a///b", None)));
}
#[test]
@ -403,7 +403,7 @@ mod tests {
(Method::GET, "/a///a//b///suffix"),
];
for (method, uri) in &hits {
assert_eq!(Match::Yes, greedy_label_spec.matches(&req(method, uri)));
assert_eq!(Match::Yes, greedy_label_spec.matches(&req(method, uri, None)));
}
}
@ -418,7 +418,7 @@ mod tests {
(Method::GET, "//a//b////"),
];
for (method, uri) in &misses {
assert_eq!(Match::No, ab_spec().matches(&req(method, uri)));
assert_eq!(Match::No, ab_spec().matches(&req(method, uri, None)));
}
}
@ -432,13 +432,13 @@ mod tests {
let misses = vec![(Method::GET, "/a"), (Method::GET, "/a//"), (Method::GET, "/a///")];
for (method, uri) in &misses {
assert_eq!(Match::No, label_spec.matches(&req(method, uri)));
assert_eq!(Match::No, label_spec.matches(&req(method, uri, None)));
}
// In the second example, the label is bound to `""`.
let hits = vec![(Method::GET, "/a/label"), (Method::GET, "/a/")];
for (method, uri) in &hits {
assert_eq!(Match::Yes, label_spec.matches(&req(method, uri)));
assert_eq!(Match::Yes, label_spec.matches(&req(method, uri, None)));
}
}
}

View File

@ -25,7 +25,8 @@ use crate::protocols::Protocol;
#[derive(Debug)]
pub enum RuntimeErrorKind {
// UnknownOperation,
/// The requested operation does not exist.
UnknownOperation,
/// Request failed to deserialize or response failed to serialize.
Serialization(crate::Error),
/// As of writing, this variant can only occur upon failure to extract an
@ -43,6 +44,7 @@ impl RuntimeErrorKind {
match self {
RuntimeErrorKind::Serialization(_) => "SerializationException",
RuntimeErrorKind::InternalFailure(_) => "InternalFailureException",
RuntimeErrorKind::UnknownOperation => "UnknownOperation",
}
}
}
@ -58,11 +60,16 @@ impl axum_core::response::IntoResponse for RuntimeError {
let status_code = match self.kind {
RuntimeErrorKind::Serialization(_) => http::StatusCode::BAD_REQUEST,
RuntimeErrorKind::InternalFailure(_) => http::StatusCode::INTERNAL_SERVER_ERROR,
RuntimeErrorKind::UnknownOperation => http::StatusCode::NOT_FOUND,
};
let body = crate::body::to_boxed(match self.protocol {
Protocol::RestJson1 => "{}",
Protocol::RestXml => "",
// See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#empty-body-serialization
Protocol::AwsJson10 => "",
// See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#empty-body-serialization
Protocol::AwsJson11 => "",
});
let mut builder = http::Response::builder();
@ -74,9 +81,9 @@ impl axum_core::response::IntoResponse for RuntimeError {
.header("Content-Type", "application/json")
.header("X-Amzn-Errortype", self.kind.name());
}
Protocol::RestXml => {
builder = builder.header("Content-Type", "application/xml");
}
Protocol::RestXml => builder = builder.header("Content-Type", "application/xml"),
Protocol::AwsJson10 => builder = builder.header("Content-Type", "application/x-amz-json-1.0"),
Protocol::AwsJson11 => builder = builder.header("Content-Type", "application/x-amz-json-1.1"),
}
builder = builder.extension(crate::extension::RuntimeErrorExtension::new(String::from(