Comply with content types for the accept header (#1424)

* Verify accept header for payloads with `@mediaType`

Comply with content-types as described in the documentation [0]

[0] https://awslabs.github.io/smithy/1.0/spec/aws/aws-restjson1-protocol.html#content-type

Signed-off-by: Daniele Ahmed <ahmeddan@amazon.com>
This commit is contained in:
82marbag 2022-06-13 05:17:02 -04:00 committed by GitHub
parent b7506ec2e5
commit c78c67e3fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 50 additions and 9 deletions

View File

@ -661,9 +661,7 @@ class ServerProtocolTestGenerator(
FailingTest(RestJson, "RestJsonHttpWithEmptyStructurePayload", TestType.Request),
FailingTest(RestJson, "RestJsonHttpResponseCodeDefaultsToModeledCode", TestType.Response),
FailingTest(RestJson, "RestJsonWithBodyExpectsApplicationJsonAccept", TestType.MalformedRequest),
FailingTest(RestJson, "RestJsonWithPayloadExpectsImpliedAccept", TestType.MalformedRequest),
FailingTest(RestJson, "RestJsonWithPayloadExpectsModeledAccept", TestType.MalformedRequest),
FailingTest(RestJson, "RestJsonBodyMalformedBlobInvalidBase64_case1", TestType.MalformedRequest),
FailingTest(RestJson, "RestJsonBodyMalformedBlobInvalidBase64_case2", TestType.MalformedRequest),
FailingTest(RestJson, "RestJsonBodyByteMalformedValueRejected_case2", TestType.MalformedRequest),

View File

@ -159,6 +159,26 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
val operationName = symbolProvider.toSymbol(operationShape).name
val inputName = "${operationName}${ServerHttpBoundProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}"
val verifyResponseContentType = writable {
httpBindingResolver.responseContentType(operationShape)?.also { contentType ->
rustTemplate(
"""
if let Some(headers) = req.headers() {
if let Some(accept) = headers.get(#{http}::header::ACCEPT) {
if accept != "$contentType" {
return Err(Self::Rejection {
protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()},
kind: #{SmithyHttpServer}::runtime_error::RuntimeErrorKind::NotAcceptable,
})
}
}
}
""",
*codegenScope,
)
}
}
// Implement `FromRequest` trait for input types.
rustTemplate(
"""
@ -173,6 +193,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
{
type Rejection = #{RuntimeError};
async fn from_request(req: &mut #{SmithyHttpServer}::request::RequestParts<B>) -> Result<Self, Self::Rejection> {
#{verify_response_content_type:W}
#{parse_request}(req)
.await
.map($inputName)
@ -187,7 +208,8 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
""".trimIndent(),
*codegenScope,
"I" to inputSymbol,
"parse_request" to serverParseRequest(operationShape)
"parse_request" to serverParseRequest(operationShape),
"verify_response_content_type" to verifyResponseContentType,
)
// Implement `IntoResponse` for output types.
@ -227,7 +249,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
}
}
}
""".trimIndent()
"""
rustTemplate(
"""

View File

@ -9,7 +9,10 @@ import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.HttpPayloadTrait
import software.amazon.smithy.model.traits.JsonNameTrait
import software.amazon.smithy.model.traits.MediaTypeTrait
import software.amazon.smithy.model.traits.StreamingTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.RustModule
@ -23,6 +26,8 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredData
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.outputShape
class RestJsonFactory : ProtocolGeneratorFactory<HttpBoundProtocolGenerator> {
override fun protocol(codegenContext: CodegenContext): Protocol = RestJson(codegenContext)
@ -56,15 +61,29 @@ class RestJsonFactory : ProtocolGeneratorFactory<HttpBoundProtocolGenerator> {
* `application/json` if not overridden.
*/
class RestJsonHttpBindingResolver(
model: Model,
private val model: Model,
contentTypes: ProtocolContentTypes,
) : HttpTraitHttpBindingResolver(model, contentTypes) {
/**
* In the RestJson1 protocol, HTTP responses have a default `Content-Type: application/json` header if it is not
* overridden by a specific mechanism e.g. an output shape member is targeted with `httpPayload` or `mediaType` traits.
*/
override fun responseContentType(operationShape: OperationShape): String =
super.responseContentType(operationShape) ?: "application/json"
override fun responseContentType(operationShape: OperationShape): String? {
val members = operationShape
.outputShape(model)
.members()
// TODO(https://github.com/awslabs/smithy/issues/1259)
// Temporary fix for https://github.com/awslabs/smithy/blob/df456a514f72f4e35f0fb07c7e26006ff03b2071/smithy-model/src/main/java/software/amazon/smithy/model/knowledge/HttpBindingIndex.java#L352
for (member in members) {
if (member.hasTrait<HttpPayloadTrait>()) {
val target = model.expectShape(member.target)
if (!target.hasTrait<StreamingTrait>() && !target.hasTrait<MediaTypeTrait>() && target.isBlobShape) {
return null
}
}
}
return super.responseContentType(operationShape) ?: "application/json"
}
}
class RestJson(private val codegenContext: CodegenContext) : Protocol {

View File

@ -36,7 +36,7 @@ pub enum RuntimeErrorKind {
/// [`crate::extension::Extension`] from the request.
InternalFailure(crate::Error),
// UnsupportedMediaType,
// NotAcceptable,
NotAcceptable,
}
/// String representation of the runtime error type.
@ -47,7 +47,8 @@ impl RuntimeErrorKind {
match self {
RuntimeErrorKind::Serialization(_) => "SerializationException",
RuntimeErrorKind::InternalFailure(_) => "InternalFailureException",
RuntimeErrorKind::UnknownOperation => "UnknownOperation",
RuntimeErrorKind::UnknownOperation => "UnknownOperationException",
RuntimeErrorKind::NotAcceptable => "NotAcceptableException",
}
}
}
@ -64,6 +65,7 @@ impl IntoResponse for RuntimeError {
RuntimeErrorKind::Serialization(_) => http::StatusCode::BAD_REQUEST,
RuntimeErrorKind::InternalFailure(_) => http::StatusCode::INTERNAL_SERVER_ERROR,
RuntimeErrorKind::UnknownOperation => http::StatusCode::NOT_FOUND,
RuntimeErrorKind::NotAcceptable => http::StatusCode::NOT_ACCEPTABLE,
};
let body = crate::body::to_boxed(match self.protocol {