diff --git a/rust-runtime/aws-smithy-http-server/Cargo.toml b/rust-runtime/aws-smithy-http-server/Cargo.toml index 95cd3c3b7a..27fcc32da5 100644 --- a/rust-runtime/aws-smithy-http-server/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server/Cargo.toml @@ -32,7 +32,8 @@ regex = "1.0" serde_urlencoded = "0.7" thiserror = "1" tokio = { version = "1.0", features = ["full"] } -tower = { version = "0.4" } +tower = "0.4" +tower-http = { version = "0.1", features = ["add-extension", "map-response-body"] } [dev-dependencies] pretty_assertions = "1" diff --git a/rust-runtime/aws-smithy-http-server/src/routing/mod.rs b/rust-runtime/aws-smithy-http-server/src/routing/mod.rs index 24eb452aec..70dff6f9c6 100644 --- a/rust-runtime/aws-smithy-http-server/src/routing/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/routing/mod.rs @@ -16,13 +16,16 @@ //! [endpoint trait]: https://awslabs.github.io/smithy/1.0/spec/core/endpoint-traits.html#endpoint-trait use self::{future::RouterFuture, request_spec::RequestSpec}; -use crate::body::{Body, BoxBody}; +use crate::body::{box_body, Body, BoxBody, HttpBody}; +use crate::BoxError; use http::{Request, Response, StatusCode}; use std::{ convert::Infallible, task::{Context, Poll}, }; -use tower::{Service, ServiceExt}; +use tower::layer::Layer; +use tower::{Service, ServiceBuilder, ServiceExt}; +use tower_http::map_response_body::MapResponseBodyLayer; pub mod future; mod into_make_service; @@ -34,7 +37,7 @@ pub use self::{into_make_service::IntoMakeService, route::Route}; #[derive(Debug)] pub struct Router { - routes: Vec>, + routes: Vec<(Route, RequestSpec)>, } impl Clone for Router { @@ -70,7 +73,7 @@ where T: Service, Response = Response, Error = Infallible> + Clone + Send + 'static, T::Future: Send + 'static, { - self.routes.push(Route::new(svc, request_spec)); + self.routes.push((Route::new(svc), request_spec)); self } @@ -84,6 +87,28 @@ where pub fn into_make_service(self) -> IntoMakeService { IntoMakeService::new(self) } + + /// Apply a [`tower::Layer`] to the router. + /// + /// All requests to the router will be processed by the layer's + /// corresponding middleware. + /// + /// This can be used to add additional processing to a request for a group + /// of routes. + pub fn layer(self, layer: L) -> Router + where + L: Layer>, + L::Service: + Service, Response = Response, Error = Infallible> + Clone + Send + 'static, + >>::Future: Send + 'static, + NewResBody: HttpBody + Send + 'static, + NewResBody::Error: Into, + { + let layer = ServiceBuilder::new().layer_fn(Route::new).layer(MapResponseBodyLayer::new(box_body)).layer(layer); + let routes = + self.routes.into_iter().map(|(route, request_spec)| (Layer::layer(&layer, route), request_spec)).collect(); + Router { routes } + } } impl Service> for Router @@ -103,8 +128,8 @@ where fn call(&mut self, req: Request) -> Self::Future { let mut method_not_allowed = false; - for route in &self.routes { - match route.matches(&req) { + for (route, request_spec) in &self.routes { + match request_spec.matches(&req) { request_spec::Match::Yes => { return RouterFuture::from_oneshot(route.clone().oneshot(req)); } @@ -148,7 +173,7 @@ mod tests { #[inline] fn call(&mut self, req: Request) -> Self::Future { - let body = box_body(Body::from(format!("{} :: {}", self.0, String::from(req.uri().to_string())))); + let body = box_body(Body::from(format!("{} :: {}", self.0, req.uri().to_string()))); let fut = async { Ok(Response::builder().status(&http::StatusCode::OK).body(body).unwrap()) }; Box::pin(fut) } diff --git a/rust-runtime/aws-smithy-http-server/src/routing/route.rs b/rust-runtime/aws-smithy-http-server/src/routing/route.rs index 713a137e7f..72a9f401ff 100644 --- a/rust-runtime/aws-smithy-http-server/src/routing/route.rs +++ b/rust-runtime/aws-smithy-http-server/src/routing/route.rs @@ -48,31 +48,24 @@ use std::{ use tower::Service; use tower::{util::Oneshot, ServiceExt}; -use super::request_spec::{Match, RequestSpec}; - /// How routes are stored inside a [`Router`](super::Router). pub struct Route { service: CloneBoxService, Response, Infallible>, - request_spec: RequestSpec, } impl Route { - pub(super) fn new(svc: T, request_spec: RequestSpec) -> Self + pub(super) fn new(svc: T) -> Self where T: Service, Response = Response, Error = Infallible> + Clone + Send + 'static, T::Future: Send + 'static, { - Self { service: CloneBoxService::new(svc), request_spec } - } - - pub(super) fn matches(&self, req: &Request) -> Match { - self.request_spec.matches(req) + Self { service: CloneBoxService::new(svc) } } } impl Clone for Route { fn clone(&self) -> Self { - Self { service: self.service.clone(), request_spec: self.request_spec.clone() } + Self { service: self.service.clone() } } }