Remove Protocol enum (#1829)

* Remove Protocol enum

* Update Python implementation
This commit is contained in:
Harry Barber 2022-10-10 14:29:37 +01:00 committed by GitHub
parent 238cf8b434
commit 78022d69ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 236 additions and 270 deletions

View File

@ -137,3 +137,15 @@ message = "Fix regression where `connect_timeout` and `read_timeout` fields are
references = ["smithy-rs#1822"]
meta = { "breaking" = false, "tada" = false, "bug" = true }
author = "kevinpark1217"
[[smithy-rs]]
message = "Remove `Protocol` enum, removing an obstruction to extending smithy to third-party protocols."
references = ["smithy-rs#1829"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "server" }
author = "hlbarber"
[[smithy-rs]]
message = "Convert the `protocol` argument on `PyMiddlewares::new` constructor to a type parameter."
references = ["smithy-rs#1829"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "server" }
author = "hlbarber"

View File

@ -23,6 +23,7 @@ import software.amazon.smithy.rust.codegen.core.util.outputShape
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
/**
* Generates a Python compatible application and server that can be configured from Python.
@ -62,13 +63,13 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
*/
class PythonApplicationGenerator(
codegenContext: CodegenContext,
private val protocol: ServerProtocol,
private val operations: List<OperationShape>,
) {
private val symbolProvider = codegenContext.symbolProvider
private val libName = "lib${codegenContext.settings.moduleName.toSnakeCase()}"
private val runtimeConfig = codegenContext.runtimeConfig
private val model = codegenContext.model
private val protocol = codegenContext.protocol
private val codegenScope =
arrayOf(
"SmithyPython" to PythonServerCargoDependency.SmithyHttpServerPython(runtimeConfig).asType(),
@ -88,6 +89,7 @@ class PythonApplicationGenerator(
fun render(writer: RustWriter) {
renderPyApplicationRustDocs(writer)
renderAppStruct(writer)
renderAppDefault(writer)
renderAppClone(writer)
renderPyAppTrait(writer)
renderAppImpl(writer)
@ -98,7 +100,7 @@ class PythonApplicationGenerator(
writer.rustTemplate(
"""
##[#{pyo3}::pyclass]
##[derive(Debug, Default)]
##[derive(Debug)]
pub struct App {
handlers: #{HashMap}<String, #{SmithyPython}::PyHandler>,
middlewares: #{SmithyPython}::PyMiddlewares,
@ -128,6 +130,25 @@ class PythonApplicationGenerator(
)
}
private fun renderAppDefault(writer: RustWriter) {
writer.rustTemplate(
"""
impl Default for App {
fn default() -> Self {
Self {
handlers: Default::default(),
middlewares: #{SmithyPython}::PyMiddlewares::new::<#{Protocol}>(vec![]),
context: None,
workers: #{parking_lot}::Mutex::new(vec![]),
}
}
}
""",
"Protocol" to protocol.markerStruct(),
*codegenScope,
)
}
private fun renderAppImpl(writer: RustWriter) {
writer.rustBlockTemplate(
"""
@ -165,13 +186,9 @@ class PythonApplicationGenerator(
rustTemplate(
"""
let middleware_locals = pyo3_asyncio::TaskLocals::new(event_loop);
use #{SmithyPython}::PyApp;
let service = #{tower}::ServiceBuilder::new().layer(
#{SmithyPython}::PyMiddlewareLayer::new(
self.middlewares.clone(),
self.protocol(),
middleware_locals
)?,
let service = #{tower}::ServiceBuilder::new()
.layer(
#{SmithyPython}::PyMiddlewareLayer::<#{Protocol}>::new(self.middlewares.clone(), middleware_locals),
);
let router: #{SmithyServer}::routing::Router = router
.build()
@ -179,6 +196,7 @@ class PythonApplicationGenerator(
.into();
Ok(router.layer(service))
""",
"Protocol" to protocol.markerStruct(),
*codegenScope,
)
}
@ -186,7 +204,6 @@ class PythonApplicationGenerator(
}
private fun renderPyAppTrait(writer: RustWriter) {
val protocol = protocol.toString().replace("#", "##")
writer.rustTemplate(
"""
impl #{SmithyPython}::PyApp for App {
@ -202,9 +219,6 @@ class PythonApplicationGenerator(
fn middlewares(&mut self) -> &mut #{SmithyPython}::PyMiddlewares {
&mut self.middlewares
}
fn protocol(&self) -> &'static str {
"$protocol"
}
}
""",
*codegenScope,

View File

@ -16,6 +16,7 @@ import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerOperationHandlerGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
/**
* The Rust code responsible to run the Python business logic on the Python interpreter
@ -33,8 +34,9 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerOperat
*/
class PythonServerOperationHandlerGenerator(
codegenContext: CodegenContext,
protocol: ServerProtocol,
private val operations: List<OperationShape>,
) : ServerOperationHandlerGenerator(codegenContext, operations) {
) : ServerOperationHandlerGenerator(codegenContext, protocol, operations) {
private val symbolProvider = codegenContext.symbolProvider
private val runtimeConfig = codegenContext.runtimeConfig
private val codegenScope =

View File

@ -34,12 +34,12 @@ class PythonServerServiceGenerator(
}
override fun renderOperationHandler(writer: RustWriter, operations: List<OperationShape>) {
PythonServerOperationHandlerGenerator(context, operations).render(writer)
PythonServerOperationHandlerGenerator(context, protocol, operations).render(writer)
}
override fun renderExtras(operations: List<OperationShape>) {
rustCrate.withModule(RustModule.public("python_server_application", "Python server and application implementation.")) { writer ->
PythonApplicationGenerator(context, operations)
PythonApplicationGenerator(context, protocol, operations)
.render(writer)
}
}

View File

@ -19,9 +19,9 @@ import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErr
import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.core.util.inputShape
import software.amazon.smithy.rust.codegen.core.util.outputShape
import software.amazon.smithy.rust.codegen.core.util.toPascalCase
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpBoundProtocolGenerator
/**
@ -29,12 +29,11 @@ import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpBou
*/
open class ServerOperationHandlerGenerator(
codegenContext: CodegenContext,
val protocol: ServerProtocol,
private val operations: List<OperationShape>,
) {
private val serverCrate = "aws_smithy_http_server"
private val service = codegenContext.serviceShape
private val model = codegenContext.model
private val protocol = codegenContext.protocol
private val symbolProvider = codegenContext.symbolProvider
private val runtimeConfig = codegenContext.runtimeConfig
private val codegenScope = arrayOf(
@ -83,11 +82,8 @@ open class ServerOperationHandlerGenerator(
Ok(v) => v,
Err(extension_not_found_rejection) => {
let extension = $serverCrate::extension::RuntimeErrorExtension::new(extension_not_found_rejection.to_string());
let runtime_error = $serverCrate::runtime_error::RuntimeError {
protocol: #{SmithyHttpServer}::protocols::Protocol::${protocol.name.toPascalCase()},
kind: extension_not_found_rejection.into(),
};
let mut response = runtime_error.into_response();
let runtime_error = $serverCrate::runtime_error::RuntimeError::from(extension_not_found_rejection);
let mut response = #{SmithyHttpServer}::response::IntoResponse::<#{Protocol}>::into_response(runtime_error);
response.extensions_mut().insert(extension);
return response.map($serverCrate::body::boxed);
}
@ -109,7 +105,8 @@ open class ServerOperationHandlerGenerator(
let input_wrapper = match $inputWrapperName::from_request(&mut req).await {
Ok(v) => v,
Err(runtime_error) => {
return runtime_error.into_response().map($serverCrate::body::boxed);
let response = #{SmithyHttpServer}::response::IntoResponse::<#{Protocol}>::into_response(runtime_error);
return response.map($serverCrate::body::boxed);
}
};
$callImpl
@ -120,6 +117,7 @@ open class ServerOperationHandlerGenerator(
response.map(#{SmithyHttpServer}::body::boxed)
}
""",
"Protocol" to protocol.markerStruct(),
*codegenScope,
)
}

View File

@ -29,7 +29,7 @@ open class ServerServiceGenerator(
private val rustCrate: RustCrate,
private val protocolGenerator: ServerProtocolGenerator,
private val protocolSupport: ProtocolSupport,
private val protocol: ServerProtocol,
val protocol: ServerProtocol,
private val codegenContext: CodegenContext,
) {
private val index = TopDownIndex.of(codegenContext.model)
@ -107,7 +107,7 @@ open class ServerServiceGenerator(
// Render operations handler.
open fun renderOperationHandler(writer: RustWriter, operations: List<OperationShape>) {
ServerOperationHandlerGenerator(codegenContext, operations).render(writer)
ServerOperationHandlerGenerator(codegenContext, protocol, operations).render(writer)
}
// Render operations registry.

View File

@ -11,11 +11,10 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.MakeOperationGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTraitImplGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
open class ServerProtocolGenerator(
codegenContext: CodegenContext,
protocol: Protocol,
val protocol: ServerProtocol,
makeOperationGenerator: MakeOperationGenerator,
private val traitGenerator: ProtocolTraitImplGenerator,
) : ProtocolGenerator(codegenContext, protocol, makeOperationGenerator, traitGenerator) {

View File

@ -452,8 +452,9 @@ class ServerProtocolTestGenerator(
"""
let mut http_request = #{SmithyHttpServer}::request::RequestParts::new(http_request);
let rejection = super::$operationName::from_request(&mut http_request).await.expect_err("request was accepted but we expected it to be rejected");
let http_response = rejection.into_response();
let http_response = #{SmithyHttpServer}::response::IntoResponse::<#{Protocol}>::into_response(rejection);
""",
"Protocol" to protocolGenerator.protocol.markerStruct(),
*codegenScope,
)
checkResponse(this, testCase.response)

View File

@ -70,7 +70,6 @@ import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.inputShape
import software.amazon.smithy.rust.codegen.core.util.isStreaming
import software.amazon.smithy.rust.codegen.core.util.outputShape
import software.amazon.smithy.rust.codegen.core.util.toPascalCase
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
@ -177,10 +176,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
rustTemplate(
"""
if ! #{SmithyHttpServer}::protocols::accept_header_classifier(req, ${contentType.dq()}) {
return Err(#{RuntimeError} {
protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()},
kind: #{SmithyHttpServer}::runtime_error::RuntimeErrorKind::NotAcceptable,
})
return Err(#{RuntimeError}::NotAcceptable)
}
""",
*codegenScope,
@ -200,10 +196,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
rustTemplate(
"""
if #{SmithyHttpServer}::protocols::content_type_header_classifier(req, $expectedRequestContentType).is_err() {
return Err(#{RuntimeError} {
protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()},
kind: #{SmithyHttpServer}::runtime_error::RuntimeErrorKind::UnsupportedMediaType,
})
return Err(#{RuntimeError}::UnsupportedMediaType)
}
""",
*codegenScope,
@ -230,12 +223,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
#{parse_request}(req)
.await
.map($inputName)
.map_err(
|err| #{RuntimeError} {
protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()},
kind: err.into()
}
)
.map_err(Into::into)
}
}
@ -282,12 +270,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
Self::Output(o) => {
match #{serialize_response}(o) {
Ok(response) => response,
Err(e) => {
#{RuntimeError} {
protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()},
kind: e.into()
}.into_response()
}
Err(e) => #{SmithyHttpServer}::response::IntoResponse::<#{Marker}>::into_response(#{RuntimeError}::from(e))
}
},
Self::Error(err) => {
@ -296,12 +279,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
response.extensions_mut().insert(#{SmithyHttpServer}::extension::ModeledErrorExtension::new(err.name()));
response
},
Err(e) => {
#{RuntimeError} {
protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()},
kind: e.into()
}.into_response()
}
Err(e) => #{SmithyHttpServer}::response::IntoResponse::<#{Marker}>::into_response(#{RuntimeError}::from(e))
}
}
}
@ -346,12 +324,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
"""
match #{serialize_response}(self.0) {
Ok(response) => response,
Err(e) => {
#{RuntimeError} {
protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()},
kind: e.into()
}.into_response()
}
Err(e) => #{SmithyHttpServer}::response::IntoResponse::<#{Marker}>::into_response(#{RuntimeError}::from(e))
}
""".trimIndent()

View File

@ -5,8 +5,14 @@
//! Python error definition.
use aws_smithy_http_server::protocols::Protocol;
use aws_smithy_http_server::{body::to_boxed, response::Response};
use aws_smithy_http_server::{
body::{to_boxed, BoxBody},
proto::{
aws_json_10::AwsJson10, aws_json_11::AwsJson11, rest_json_1::AwsRestJson1,
rest_xml::AwsRestXml,
},
response::IntoResponse,
};
use aws_smithy_types::date_time::{ConversionError, DateTimeParseError};
use pyo3::{create_exception, exceptions::PyException as BasePyException, prelude::*, PyErr};
use thiserror::Error;
@ -62,39 +68,50 @@ impl From<PyErr> for PyMiddlewareException {
}
}
impl PyMiddlewareException {
/// Convert the exception into a [Response], following the [Protocol] specification.
pub(crate) fn into_response(self, protocol: Protocol) -> Response {
let body = to_boxed(match protocol {
Protocol::RestJson1 => self.json_body(),
Protocol::RestXml => self.xml_body(),
// See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#empty-body-serialization
Protocol::AwsJson10 => self.json_body(),
// See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#empty-body-serialization
Protocol::AwsJson11 => self.json_body(),
});
let mut builder = http::Response::builder();
builder = builder.status(self.status_code);
match protocol {
Protocol::RestJson1 => {
builder = builder
impl IntoResponse<AwsRestJson1> for PyMiddlewareException {
fn into_response(self) -> http::Response<BoxBody> {
http::Response::builder()
.status(self.status_code)
.header("Content-Type", "application/json")
.header("X-Amzn-Errortype", "MiddlewareException");
}
Protocol::RestXml => builder = builder.header("Content-Type", "application/xml"),
Protocol::AwsJson10 => {
builder = builder.header("Content-Type", "application/x-amz-json-1.0")
}
Protocol::AwsJson11 => {
builder = builder.header("Content-Type", "application/x-amz-json-1.1")
}
.header("X-Amzn-Errortype", "MiddlewareException")
.body(to_boxed(self.json_body()))
.expect("invalid HTTP response for `MiddlewareException`; please file a bug report under https://github.com/awslabs/smithy-rs/issues")
}
}
builder.body(body).expect("invalid HTTP response for `MiddlewareException`; please file a bug report under https://github.com/awslabs/smithy-rs/issues")
impl IntoResponse<AwsRestXml> for PyMiddlewareException {
fn into_response(self) -> http::Response<BoxBody> {
http::Response::builder()
.status(self.status_code)
.header("Content-Type", "application/xml")
.body(to_boxed(self.xml_body()))
.expect("invalid HTTP response for `MiddlewareException`; please file a bug report under https://github.com/awslabs/smithy-rs/issues")
}
}
impl IntoResponse<AwsJson10> for PyMiddlewareException {
fn into_response(self) -> http::Response<BoxBody> {
http::Response::builder()
.status(self.status_code)
.header("Content-Type", "application/x-amz-json-1.0")
// See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#empty-body-serialization
.body(to_boxed(self.json_body()))
.expect("invalid HTTP response for `MiddlewareException`; please file a bug report under https://github.com/awslabs/smithy-rs/issues")
}
}
impl IntoResponse<AwsJson11> for PyMiddlewareException {
fn into_response(self) -> http::Response<BoxBody> {
http::Response::builder()
.status(self.status_code)
.header("Content-Type", "application/x-amz-json-1.1")
// See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#empty-body-serialization
.body(to_boxed(self.json_body()))
.expect("invalid HTTP response for `MiddlewareException`; please file a bug report under https://github.com/awslabs/smithy-rs/issues")
}
}
impl PyMiddlewareException {
/// Serialize the body into a JSON object.
fn json_body(&self) -> String {
let mut out = String::new();

View File

@ -4,11 +4,10 @@
*/
//! Execute Python middleware handlers.
use aws_smithy_http_server::body::Body;
use aws_smithy_http_server::{body::Body, body::BoxBody, response::IntoResponse};
use http::Request;
use pyo3::prelude::*;
use aws_smithy_http_server::protocols::Protocol;
use pyo3_asyncio::TaskLocals;
use crate::{PyMiddlewareException, PyRequest, PyResponse};
@ -36,18 +35,27 @@ pub struct PyMiddlewareHandler {
/// Structure holding the list of Python middlewares that will be executed by this server.
///
/// Middlewares are executed one after each other inside the [crate::PyMiddlewareLayer] Tower layer.
#[derive(Debug, Clone, Default)]
pub struct PyMiddlewares(Vec<PyMiddlewareHandler>);
#[derive(Debug, Clone)]
pub struct PyMiddlewares {
handlers: Vec<PyMiddlewareHandler>,
into_response: fn(PyMiddlewareException) -> http::Response<BoxBody>,
}
impl PyMiddlewares {
/// Create a new instance of `PyMiddlewareHandlers` from a list of heandlers.
pub fn new(handlers: Vec<PyMiddlewareHandler>) -> Self {
Self(handlers)
pub fn new<P>(handlers: Vec<PyMiddlewareHandler>) -> Self
where
PyMiddlewareException: IntoResponse<P>,
{
Self {
handlers,
into_response: PyMiddlewareException::into_response,
}
}
/// Add a new handler to the list.
pub fn push(&mut self, handler: PyMiddlewareHandler) {
self.0.push(handler);
self.handlers.push(handler);
}
/// Execute a single middleware handler.
@ -114,13 +122,9 @@ impl PyMiddlewares {
/// and return a protocol specific error, with the option of setting the HTTP return code.
/// * Middleware raising any other exception will immediately terminate the request handling and
/// return a protocol specific error, with HTTP status code 500.
pub fn run(
&mut self,
mut request: Request<Body>,
protocol: Protocol,
locals: TaskLocals,
) -> PyFuture {
let handlers = self.0.clone();
pub fn run(&mut self, mut request: Request<Body>, locals: TaskLocals) -> PyFuture {
let handlers = self.handlers.clone();
let into_response = self.into_response;
// Run all Python handlers in a loop.
Box::pin(async move {
tracing::debug!("Executing Python middleware stack");
@ -152,7 +156,7 @@ impl PyMiddlewares {
tracing::debug!(
"Middleware `{name}` returned an error, exit middleware loop"
);
return Err(e.into_response(protocol));
return Err((into_response)(e));
}
}
}
@ -166,6 +170,7 @@ impl PyMiddlewares {
#[cfg(test)]
mod tests {
use aws_smithy_http_server::proto::rest_json_1::AwsRestJson1;
use http::HeaderValue;
use hyper::body::to_bytes;
use pretty_assertions::assert_eq;
@ -175,7 +180,7 @@ mod tests {
#[tokio::test]
async fn request_middleware_chain_keeps_headers_changes() -> PyResult<()> {
let locals = crate::tests::initialize();
let mut middlewares = PyMiddlewares(vec![]);
let mut middlewares = PyMiddlewares::new::<AwsRestJson1>(vec![]);
Python::with_gil(|py| {
let middleware = PyModule::new(py, "middleware").unwrap();
@ -212,11 +217,7 @@ def second_middleware(request: Request):
})?;
let result = middlewares
.run(
Request::builder().body(Body::from("")).unwrap(),
Protocol::RestJson1,
locals,
)
.run(Request::builder().body(Body::from("")).unwrap(), locals)
.await
.unwrap();
assert_eq!(
@ -229,7 +230,7 @@ def second_middleware(request: Request):
#[tokio::test]
async fn request_middleware_return_response() -> PyResult<()> {
let locals = crate::tests::initialize();
let mut middlewares = PyMiddlewares(vec![]);
let mut middlewares = PyMiddlewares::new::<AwsRestJson1>(vec![]);
Python::with_gil(|py| {
let middleware = PyModule::new(py, "middleware").unwrap();
@ -252,11 +253,7 @@ def middleware(request: Request):
})?;
let result = middlewares
.run(
Request::builder().body(Body::from("")).unwrap(),
Protocol::RestJson1,
locals,
)
.run(Request::builder().body(Body::from("")).unwrap(), locals)
.await
.unwrap_err();
assert_eq!(result.status(), 200);
@ -268,7 +265,7 @@ def middleware(request: Request):
#[tokio::test]
async fn request_middleware_raise_middleware_exception() -> PyResult<()> {
let locals = crate::tests::initialize();
let mut middlewares = PyMiddlewares(vec![]);
let mut middlewares = PyMiddlewares::new::<AwsRestJson1>(vec![]);
Python::with_gil(|py| {
let middleware = PyModule::new(py, "middleware").unwrap();
@ -291,11 +288,7 @@ def middleware(request: Request):
})?;
let result = middlewares
.run(
Request::builder().body(Body::from("")).unwrap(),
Protocol::RestJson1,
locals,
)
.run(Request::builder().body(Body::from("")).unwrap(), locals)
.await
.unwrap_err();
assert_eq!(result.status(), 503);
@ -311,7 +304,7 @@ def middleware(request: Request):
#[tokio::test]
async fn request_middleware_raise_python_exception() -> PyResult<()> {
let locals = crate::tests::initialize();
let mut middlewares = PyMiddlewares(vec![]);
let mut middlewares = PyMiddlewares::new::<AwsRestJson1>(vec![]);
Python::with_gil(|py| {
let middleware = PyModule::from_code(
@ -333,11 +326,7 @@ def middleware(request):
})?;
let result = middlewares
.run(
Request::builder().body(Body::from("")).unwrap(),
Protocol::RestJson1,
locals,
)
.run(Request::builder().body(Body::from("")).unwrap(), locals)
.await
.unwrap_err();
assert_eq!(result.status(), 500);

View File

@ -5,69 +5,52 @@
//! Tower layer implementation of Python middleware handling.
use std::{
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};
use aws_smithy_http_server::{
body::{Body, BoxBody},
protocols::Protocol,
response::IntoResponse,
};
use futures::{ready, Future};
use http::{Request, Response};
use pin_project_lite::pin_project;
use pyo3::PyResult;
use pyo3_asyncio::TaskLocals;
use tower::{Layer, Service};
use crate::{error::PyException, middleware::PyFuture, PyMiddlewares};
use crate::{middleware::PyFuture, PyMiddlewareException, PyMiddlewares};
/// Tower [Layer] implementation of Python middleware handling.
///
/// Middleware stored in the `handlers` attribute will be executed, in order,
/// inside an async Tower middleware.
#[derive(Debug, Clone)]
pub struct PyMiddlewareLayer {
pub struct PyMiddlewareLayer<P> {
handlers: PyMiddlewares,
protocol: Protocol,
locals: TaskLocals,
_protocol: PhantomData<P>,
}
impl PyMiddlewareLayer {
pub fn new(
handlers: PyMiddlewares,
protocol: &str,
locals: TaskLocals,
) -> PyResult<PyMiddlewareLayer> {
let protocol = match protocol {
"aws.protocols#restJson1" => Protocol::RestJson1,
"aws.protocols#restXml" => Protocol::RestXml,
"aws.protocols#awsjson10" => Protocol::AwsJson10,
"aws.protocols#awsjson11" => Protocol::AwsJson11,
_ => {
return Err(PyException::new_err(format!(
"Protocol {protocol} is not supported"
)))
}
};
Ok(Self {
impl<P> PyMiddlewareLayer<P> {
pub fn new(handlers: PyMiddlewares, locals: TaskLocals) -> Self {
Self {
handlers,
protocol,
locals,
})
_protocol: PhantomData,
}
}
}
impl<S> Layer<S> for PyMiddlewareLayer {
impl<S, P> Layer<S> for PyMiddlewareLayer<P>
where
PyMiddlewareException: IntoResponse<P>,
{
type Service = PyMiddlewareService<S>;
fn layer(&self, inner: S) -> Self::Service {
PyMiddlewareService::new(
inner,
self.handlers.clone(),
self.protocol,
self.locals.clone(),
)
PyMiddlewareService::new(inner, self.handlers.clone(), self.locals.clone())
}
}
@ -76,21 +59,14 @@ impl<S> Layer<S> for PyMiddlewareLayer {
pub struct PyMiddlewareService<S> {
inner: S,
handlers: PyMiddlewares,
protocol: Protocol,
locals: TaskLocals,
}
impl<S> PyMiddlewareService<S> {
pub fn new(
inner: S,
handlers: PyMiddlewares,
protocol: Protocol,
locals: TaskLocals,
) -> PyMiddlewareService<S> {
pub fn new(inner: S, handlers: PyMiddlewares, locals: TaskLocals) -> PyMiddlewareService<S> {
Self {
inner,
handlers,
protocol,
locals,
}
}
@ -113,7 +89,7 @@ where
let clone = self.inner.clone();
// See https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services
let inner = std::mem::replace(&mut self.inner, clone);
let run = self.handlers.run(req, self.protocol, self.locals.clone());
let run = self.handlers.run(req, self.locals.clone());
ResponseFuture {
middleware: State::Running { run },
@ -184,6 +160,7 @@ mod tests {
use super::*;
use aws_smithy_http_server::body::to_boxed;
use aws_smithy_http_server::proto::rest_json_1::AwsRestJson1;
use pyo3::prelude::*;
use tower::{Service, ServiceBuilder, ServiceExt};
@ -197,7 +174,7 @@ mod tests {
#[tokio::test]
async fn request_middlewares_are_chained_inside_layer() -> PyResult<()> {
let locals = crate::tests::initialize();
let mut middlewares = PyMiddlewares::new(vec![]);
let mut middlewares = PyMiddlewares::new::<AwsRestJson1>(vec![]);
Python::with_gil(|py| {
let middleware = PyModule::new(py, "middleware").unwrap();
@ -234,11 +211,7 @@ def second_middleware(request: Request):
})?;
let mut service = ServiceBuilder::new()
.layer(PyMiddlewareLayer::new(
middlewares,
"aws.protocols#restJson1",
locals,
)?)
.layer(PyMiddlewareLayer::<AwsRestJson1>::new(middlewares, locals))
.service_fn(echo);
let request = Request::get("/").body(Body::empty()).unwrap();

View File

@ -63,8 +63,6 @@ pub trait PyApp: Clone + pyo3::IntoPy<PyObject> {
fn middlewares(&mut self) -> &mut PyMiddlewares;
fn protocol(&self) -> &'static str;
/// Handle the graceful termination of Python workers by looping through all the
/// active workers and calling `terminate()` on them. If termination fails, this
/// method will try to `kill()` any failed worker.
@ -385,7 +383,6 @@ event_loop.add_signal_handler(signal.SIGINT,
/// fn context(&self) -> &Option<PyObject> { todo!() }
/// fn handlers(&mut self) -> &mut HashMap<String, PyHandler> { todo!() }
/// fn middlewares(&mut self) -> &mut PyMiddlewares { todo!() }
/// fn protocol(&self) -> &'static str { "proto1" }
/// }
///
/// #[pymethods]

View File

@ -131,7 +131,7 @@ impl Deref for ModeledErrorExtension {
}
/// Extension type used to store the _name_ of the [`crate::runtime_error::RuntimeError`] that
/// occurred during request handling (see [`crate::runtime_error::RuntimeErrorKind::name`]).
/// occurred during request handling (see [`crate::runtime_error::RuntimeError::name`]).
/// These are _unmodeled_ errors; the operation handler was not invoked.
#[derive(Debug, Clone)]
pub struct RuntimeErrorExtension(String);

View File

@ -7,15 +7,6 @@
use crate::rejection::MissingContentTypeReason;
use crate::request::RequestParts;
/// Supported protocols.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Protocol {
RestJson1,
RestXml,
AwsJson10,
AwsJson11,
}
/// When there are no modeled inputs,
/// a request body is empty and the content-type request header must not be set
pub fn content_type_header_empty_body_no_modeled_input<B>(

View File

@ -11,7 +11,7 @@
//! the framework, `RuntimeError` is surfaced to clients in HTTP responses: indeed, it implements
//! [`RuntimeError::into_response`]. Rejections can be "grouped" and converted into a
//! specific `RuntimeError` kind: for example, all request rejections due to serialization issues
//! can be conflated under the [`RuntimeErrorKind::Serialization`] enum variant.
//! can be conflated under the [`RuntimeError::Serialization`] enum variant.
//!
//! The HTTP response representation of the specific `RuntimeError` can be protocol-specific: for
//! example, the runtime error in the RestJson1 protocol sets the `X-Amzn-Errortype` header.
@ -21,15 +21,17 @@
//! and converts into the corresponding `RuntimeError`, and then it uses the its
//! [`RuntimeError::into_response`] method to render and send a response.
use http::StatusCode;
use crate::extension::RuntimeErrorExtension;
use crate::proto::aws_json_10::AwsJson10;
use crate::proto::aws_json_11::AwsJson11;
use crate::proto::rest_json_1::AwsRestJson1;
use crate::proto::rest_xml::AwsRestXml;
use crate::protocols::Protocol;
use crate::response::{IntoResponse, Response};
use crate::response::IntoResponse;
#[derive(Debug)]
pub enum RuntimeErrorKind {
pub enum RuntimeError {
/// Request failed to deserialize or response failed to serialize.
Serialization(crate::Error),
/// As of writing, this variant can only occur upon failure to extract an
@ -43,13 +45,22 @@ pub enum RuntimeErrorKind {
/// String representation of the runtime error type.
/// Used as the value of the `X-Amzn-Errortype` header in RestJson1.
/// Used as the value passed to construct an [`crate::extension::RuntimeErrorExtension`].
impl RuntimeErrorKind {
impl RuntimeError {
pub fn name(&self) -> &'static str {
match self {
RuntimeErrorKind::Serialization(_) => "SerializationException",
RuntimeErrorKind::InternalFailure(_) => "InternalFailureException",
RuntimeErrorKind::NotAcceptable => "NotAcceptableException",
RuntimeErrorKind::UnsupportedMediaType => "UnsupportedMediaTypeException",
Self::Serialization(_) => "SerializationException",
Self::InternalFailure(_) => "InternalFailureException",
Self::NotAcceptable => "NotAcceptableException",
Self::UnsupportedMediaType => "UnsupportedMediaTypeException",
}
}
pub fn status_code(&self) -> StatusCode {
match self {
Self::Serialization(_) => StatusCode::BAD_REQUEST,
Self::InternalFailure(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::NotAcceptable => StatusCode::NOT_ACCEPTABLE,
Self::UnsupportedMediaType => StatusCode::UNSUPPORTED_MEDIA_TYPE,
}
}
}
@ -58,104 +69,93 @@ pub struct InternalFailureException;
impl IntoResponse<AwsJson10> for InternalFailureException {
fn into_response(self) -> http::Response<crate::body::BoxBody> {
RuntimeError::internal_failure_from_protocol(Protocol::AwsJson10).into_response()
IntoResponse::<AwsJson10>::into_response(RuntimeError::InternalFailure(crate::Error::new(String::new())))
}
}
impl IntoResponse<AwsJson11> for InternalFailureException {
fn into_response(self) -> http::Response<crate::body::BoxBody> {
RuntimeError::internal_failure_from_protocol(Protocol::AwsJson11).into_response()
IntoResponse::<AwsJson11>::into_response(RuntimeError::InternalFailure(crate::Error::new(String::new())))
}
}
impl IntoResponse<AwsRestJson1> for InternalFailureException {
fn into_response(self) -> http::Response<crate::body::BoxBody> {
RuntimeError::internal_failure_from_protocol(Protocol::RestJson1).into_response()
IntoResponse::<AwsRestJson1>::into_response(RuntimeError::InternalFailure(crate::Error::new(String::new())))
}
}
impl IntoResponse<AwsRestXml> for InternalFailureException {
fn into_response(self) -> http::Response<crate::body::BoxBody> {
RuntimeError::internal_failure_from_protocol(Protocol::RestXml).into_response()
IntoResponse::<AwsRestXml>::into_response(RuntimeError::InternalFailure(crate::Error::new(String::new())))
}
}
#[derive(Debug)]
pub struct RuntimeError {
pub protocol: Protocol,
pub kind: RuntimeErrorKind,
}
impl<P> IntoResponse<P> for RuntimeError {
impl IntoResponse<AwsRestJson1> for RuntimeError {
fn into_response(self) -> http::Response<crate::body::BoxBody> {
self.into_response()
}
}
impl RuntimeError {
pub fn internal_failure_from_protocol(protocol: Protocol) -> Self {
RuntimeError {
protocol,
kind: RuntimeErrorKind::InternalFailure(crate::Error::new(String::new())),
}
}
pub fn into_response(self) -> Response {
let status_code = match self.kind {
RuntimeErrorKind::Serialization(_) => http::StatusCode::BAD_REQUEST,
RuntimeErrorKind::InternalFailure(_) => http::StatusCode::INTERNAL_SERVER_ERROR,
RuntimeErrorKind::NotAcceptable => http::StatusCode::NOT_ACCEPTABLE,
RuntimeErrorKind::UnsupportedMediaType => http::StatusCode::UNSUPPORTED_MEDIA_TYPE,
};
let body = crate::body::to_boxed(match self.protocol {
Protocol::RestJson1 => "{}",
Protocol::RestXml => "",
// See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#empty-body-serialization
Protocol::AwsJson10 => "",
// See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#empty-body-serialization
Protocol::AwsJson11 => "",
});
let mut builder = http::Response::builder();
builder = builder.status(status_code);
match self.protocol {
Protocol::RestJson1 => {
builder = builder
http::Response::builder()
.status(self.status_code())
.header("Content-Type", "application/json")
.header("X-Amzn-Errortype", self.kind.name());
}
Protocol::RestXml => builder = builder.header("Content-Type", "application/xml"),
Protocol::AwsJson10 => builder = builder.header("Content-Type", "application/x-amz-json-1.0"),
Protocol::AwsJson11 => builder = builder.header("Content-Type", "application/x-amz-json-1.1"),
}
builder = builder.extension(crate::extension::RuntimeErrorExtension::new(String::from(
self.kind.name(),
)));
builder.body(body).expect("invalid HTTP response for `RuntimeError`; please file a bug report under https://github.com/awslabs/smithy-rs/issues")
.header("X-Amzn-Errortype", self.name())
.extension(RuntimeErrorExtension::new(self.name().to_string()))
// See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#empty-body-serialization
.body(crate::body::to_boxed("{}"))
.expect("invalid HTTP response for `RuntimeError`; please file a bug report under https://github.com/awslabs/smithy-rs/issues")
}
}
impl From<crate::rejection::RequestExtensionNotFoundRejection> for RuntimeErrorKind {
impl IntoResponse<AwsRestXml> for RuntimeError {
fn into_response(self) -> http::Response<crate::body::BoxBody> {
http::Response::builder()
.status(self.status_code())
.header("Content-Type", "application/xml")
.extension(RuntimeErrorExtension::new(self.name().to_string()))
.body(crate::body::to_boxed(""))
.expect("invalid HTTP response for `RuntimeError`; please file a bug report under https://github.com/awslabs/smithy-rs/issues")
}
}
impl IntoResponse<AwsJson10> for RuntimeError {
fn into_response(self) -> http::Response<crate::body::BoxBody> {
http::Response::builder()
.status(self.status_code())
.header("Content-Type", "application/x-amz-json-1.0")
.extension(RuntimeErrorExtension::new(self.name().to_string()))
// See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#empty-body-serialization
.body(crate::body::to_boxed(""))
.expect("invalid HTTP response for `RuntimeError`; please file a bug report under https://github.com/awslabs/smithy-rs/issues")
}
}
impl IntoResponse<AwsJson11> for RuntimeError {
fn into_response(self) -> http::Response<crate::body::BoxBody> {
http::Response::builder()
.status(self.status_code())
.header("Content-Type", "application/x-amz-json-1.1")
.extension(RuntimeErrorExtension::new(self.name().to_string()))
// See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#empty-body-serialization
.body(crate::body::to_boxed(""))
.expect("invalid HTTP response for `RuntimeError`; please file a bug report under https://github.com/awslabs/smithy-rs/issues")
}
}
impl From<crate::rejection::RequestExtensionNotFoundRejection> for RuntimeError {
fn from(err: crate::rejection::RequestExtensionNotFoundRejection) -> Self {
RuntimeErrorKind::InternalFailure(crate::Error::new(err))
Self::InternalFailure(crate::Error::new(err))
}
}
impl From<crate::rejection::ResponseRejection> for RuntimeErrorKind {
impl From<crate::rejection::ResponseRejection> for RuntimeError {
fn from(err: crate::rejection::ResponseRejection) -> Self {
RuntimeErrorKind::Serialization(crate::Error::new(err))
Self::Serialization(crate::Error::new(err))
}
}
impl From<crate::rejection::RequestRejection> for RuntimeErrorKind {
impl From<crate::rejection::RequestRejection> for RuntimeError {
fn from(err: crate::rejection::RequestRejection) -> Self {
match err {
crate::rejection::RequestRejection::MissingContentType(_reason) => RuntimeErrorKind::UnsupportedMediaType,
_ => RuntimeErrorKind::Serialization(crate::Error::new(err)),
crate::rejection::RequestRejection::MissingContentType(_reason) => Self::UnsupportedMediaType,
_ => Self::Serialization(crate::Error::new(err)),
}
}
}