diff --git a/src/handler/future.rs b/src/handler/future.rs index 0810dc52..3c78cb42 100644 --- a/src/handler/future.rs +++ b/src/handler/future.rs @@ -1,43 +1,38 @@ //! Handler future types. use crate::body::{box_body, BoxBody}; +use futures_util::future::{BoxFuture, Either}; use http::{Method, Request, Response}; use http_body::Empty; use pin_project_lite::pin_project; use std::{ - convert::Infallible, + fmt, future::Future, pin::Pin, task::{Context, Poll}, }; -use tower::Service; - -opaque_future! { - /// The response future for [`IntoService`](super::IntoService). - pub type IntoServiceFuture = - futures_util::future::BoxFuture<'static, Result, Infallible>>; -} +use tower::{util::Oneshot, Service}; pin_project! { /// The response future for [`OnMethod`](super::OnMethod). - #[derive(Debug)] - pub struct OnMethodFuture + pub struct OnMethodFuture where - S: Service>, F: Service> { #[pin] - pub(super) inner: crate::routing::future::RouteFuture, + pub(super) inner: Either< + BoxFuture<'static, Result, F::Error>>, + Oneshot>, + >, pub(super) req_method: Method, } } -impl Future for OnMethodFuture +impl Future for OnMethodFuture where - S: Service, Response = Response>, - F: Service, Response = Response, Error = S::Error>, + F: Service, Response = Response>, { - type Output = Result, S::Error>; + type Output = Result, F::Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); @@ -50,3 +45,12 @@ where } } } + +impl fmt::Debug for OnMethodFuture +where + F: Service>, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("OnMethodFuture").finish() + } +} diff --git a/src/handler/mod.rs b/src/handler/mod.rs index c94f69ea..cda572da 100644 --- a/src/handler/mod.rs +++ b/src/handler/mod.rs @@ -4,11 +4,12 @@ use crate::{ body::{box_body, BoxBody}, extract::FromRequest, response::IntoResponse, - routing::{future::RouteFuture, EmptyRouter, MethodFilter}, + routing::{EmptyRouter, MethodFilter}, service::{HandleError, HandleErrorFromRouter}, }; use async_trait::async_trait; use bytes::Bytes; +use futures_util::future::Either; use http::{Request, Response}; use std::{ convert::Infallible, @@ -37,7 +38,7 @@ pub mod future; /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` -pub fn any(handler: H) -> OnMethod, EmptyRouter> +pub fn any(handler: H) -> OnMethod where H: Handler, { @@ -47,7 +48,7 @@ where /// Route `CONNECT` requests to the given handler. /// /// See [`get`] for an example. -pub fn connect(handler: H) -> OnMethod, EmptyRouter> +pub fn connect(handler: H) -> OnMethod where H: Handler, { @@ -57,7 +58,7 @@ where /// Route `DELETE` requests to the given handler. /// /// See [`get`] for an example. -pub fn delete(handler: H) -> OnMethod, EmptyRouter> +pub fn delete(handler: H) -> OnMethod where H: Handler, { @@ -83,7 +84,7 @@ where /// Note that `get` routes will also be called for `HEAD` requests but will have /// the response body removed. Make sure to add explicit `HEAD` routes /// afterwards. -pub fn get(handler: H) -> OnMethod, EmptyRouter> +pub fn get(handler: H) -> OnMethod where H: Handler, { @@ -93,7 +94,7 @@ where /// Route `HEAD` requests to the given handler. /// /// See [`get`] for an example. -pub fn head(handler: H) -> OnMethod, EmptyRouter> +pub fn head(handler: H) -> OnMethod where H: Handler, { @@ -103,7 +104,7 @@ where /// Route `OPTIONS` requests to the given handler. /// /// See [`get`] for an example. -pub fn options(handler: H) -> OnMethod, EmptyRouter> +pub fn options(handler: H) -> OnMethod where H: Handler, { @@ -113,7 +114,7 @@ where /// Route `PATCH` requests to the given handler. /// /// See [`get`] for an example. -pub fn patch(handler: H) -> OnMethod, EmptyRouter> +pub fn patch(handler: H) -> OnMethod where H: Handler, { @@ -123,7 +124,7 @@ where /// Route `POST` requests to the given handler. /// /// See [`get`] for an example. -pub fn post(handler: H) -> OnMethod, EmptyRouter> +pub fn post(handler: H) -> OnMethod where H: Handler, { @@ -133,7 +134,7 @@ where /// Route `PUT` requests to the given handler. /// /// See [`get`] for an example. -pub fn put(handler: H) -> OnMethod, EmptyRouter> +pub fn put(handler: H) -> OnMethod where H: Handler, { @@ -143,7 +144,7 @@ where /// Route `TRACE` requests to the given handler. /// /// See [`get`] for an example. -pub fn trace(handler: H) -> OnMethod, EmptyRouter> +pub fn trace(handler: H) -> OnMethod where H: Handler, { @@ -165,14 +166,15 @@ where /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` -pub fn on(method: MethodFilter, handler: H) -> OnMethod, EmptyRouter> +pub fn on(method: MethodFilter, handler: H) -> OnMethod where H: Handler, { OnMethod { method, - svc: handler.into_service(), + handler, fallback: EmptyRouter::method_not_allowed(), + _marker: PhantomData, } } @@ -191,7 +193,7 @@ pub(crate) mod sealed { /// /// See the [module docs](crate::handler) for more details. #[async_trait] -pub trait Handler: Sized { +pub trait Handler: Clone + Send + Sized + 'static { // This seals the trait. We cannot use the regular "sealed super trait" // approach due to coherence. #[doc(hidden)] @@ -231,23 +233,18 @@ pub trait Handler: Sized { /// /// When adding middleware that might fail its recommended to handle those /// errors. See [`Layered::handle_error`] for more details. - fn layer(self, layer: L) -> Layered + fn layer(self, layer: L) -> Layered where - L: Layer>, + L: Layer>, { - Layered::new(layer.layer(IntoService::new(self))) - } - - /// Convert the handler into a [`Service`]. - fn into_service(self) -> IntoService { - IntoService::new(self) + Layered::new(layer.layer(any(self))) } } #[async_trait] impl Handler for F where - F: FnOnce() -> Fut + Send + Sync, + F: FnOnce() -> Fut + Clone + Send + Sync + 'static, Fut: Future + Send, Res: IntoResponse, B: Send + 'static, @@ -268,7 +265,7 @@ macro_rules! impl_handler { #[allow(non_snake_case)] impl Handler for F where - F: FnOnce($head, $($tail,)*) -> Fut + Send + Sync, + F: FnOnce($head, $($tail,)*) -> Fut + Clone + Send + Sync + 'static, Fut: Future + Send, B: Send + 'static, Res: IntoResponse, @@ -334,9 +331,10 @@ where #[async_trait] impl Handler for Layered where - S: Service, Response = Response> + Send, + S: Service, Response = Response> + Clone + Send + 'static, S::Error: IntoResponse, S::Future: Send, + T: 'static, ReqBody: Send + 'static, ResBody: http_body::Body + Send + Sync + 'static, ResBody::Error: Into + Send + Sync + 'static, @@ -387,89 +385,59 @@ impl Layered { } } -/// An adapter that makes a [`Handler`] into a [`Service`]. -/// -/// Created with [`Handler::into_service`]. -pub struct IntoService { - handler: H, - _marker: PhantomData (B, T)>, +/// A handler [`Service`] that accepts requests based on a [`MethodFilter`] and +/// allows chaining additional handlers. +pub struct OnMethod { + pub(crate) method: MethodFilter, + pub(crate) handler: H, + pub(crate) fallback: F, + pub(crate) _marker: PhantomData (B, T)>, } -impl IntoService { - fn new(handler: H) -> Self { - Self { - handler, - _marker: PhantomData, - } - } -} - -impl fmt::Debug for IntoService +impl fmt::Debug for OnMethod where - H: fmt::Debug, + T: fmt::Debug, + F: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("IntoService") - .field("handler", &self.handler) + f.debug_struct("OnMethod") + .field("method", &self.method) + .field("handler", &format_args!("{}", std::any::type_name::())) + .field("fallback", &self.fallback) .finish() } } -impl Clone for IntoService +impl Clone for OnMethod where H: Clone, + F: Clone, { fn clone(&self) -> Self { Self { + method: self.method, handler: self.handler.clone(), + fallback: self.fallback.clone(), _marker: PhantomData, } } } -impl Service> for IntoService +impl Copy for OnMethod where - H: Handler + Clone + Send + 'static, - B: Send + 'static, + H: Copy, + F: Copy, { - type Response = Response; - type Error = Infallible; - type Future = future::IntoServiceFuture; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - // `IntoService` can only be constructed from async functions which are always ready, or from - // `Layered` which bufferes in `::call` and is therefore also always - // ready. - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: Request) -> Self::Future { - let handler = self.handler.clone(); - let future = Box::pin(async move { - let res = Handler::call(handler, req).await; - Ok(res) - }); - future::IntoServiceFuture { future } - } } -/// A handler [`Service`] that accepts requests based on a [`MethodFilter`] and -/// allows chaining additional handlers. -#[derive(Debug, Clone, Copy)] -pub struct OnMethod { - pub(crate) method: MethodFilter, - pub(crate) svc: S, - pub(crate) fallback: F, -} - -impl OnMethod { +impl OnMethod { /// Chain an additional handler that will accept all requests regardless of /// its HTTP method. /// /// See [`OnMethod::get`] for an example. - pub fn any(self, handler: H) -> OnMethod, Self> + pub fn any(self, handler: H2) -> OnMethod where - H: Handler, + H2: Handler, { self.on(MethodFilter::all(), handler) } @@ -477,9 +445,9 @@ impl OnMethod { /// Chain an additional handler that will only accept `CONNECT` requests. /// /// See [`OnMethod::get`] for an example. - pub fn connect(self, handler: H) -> OnMethod, Self> + pub fn connect(self, handler: H2) -> OnMethod where - H: Handler, + H2: Handler, { self.on(MethodFilter::CONNECT, handler) } @@ -487,9 +455,9 @@ impl OnMethod { /// Chain an additional handler that will only accept `DELETE` requests. /// /// See [`OnMethod::get`] for an example. - pub fn delete(self, handler: H) -> OnMethod, Self> + pub fn delete(self, handler: H2) -> OnMethod where - H: Handler, + H2: Handler, { self.on(MethodFilter::DELETE, handler) } @@ -516,9 +484,9 @@ impl OnMethod { /// Note that `get` routes will also be called for `HEAD` requests but will have /// the response body removed. Make sure to add explicit `HEAD` routes /// afterwards. - pub fn get(self, handler: H) -> OnMethod, Self> + pub fn get(self, handler: H2) -> OnMethod where - H: Handler, + H2: Handler, { self.on(MethodFilter::GET | MethodFilter::HEAD, handler) } @@ -526,9 +494,9 @@ impl OnMethod { /// Chain an additional handler that will only accept `HEAD` requests. /// /// See [`OnMethod::get`] for an example. - pub fn head(self, handler: H) -> OnMethod, Self> + pub fn head(self, handler: H2) -> OnMethod where - H: Handler, + H2: Handler, { self.on(MethodFilter::HEAD, handler) } @@ -536,9 +504,9 @@ impl OnMethod { /// Chain an additional handler that will only accept `OPTIONS` requests. /// /// See [`OnMethod::get`] for an example. - pub fn options(self, handler: H) -> OnMethod, Self> + pub fn options(self, handler: H2) -> OnMethod where - H: Handler, + H2: Handler, { self.on(MethodFilter::OPTIONS, handler) } @@ -546,9 +514,9 @@ impl OnMethod { /// Chain an additional handler that will only accept `PATCH` requests. /// /// See [`OnMethod::get`] for an example. - pub fn patch(self, handler: H) -> OnMethod, Self> + pub fn patch(self, handler: H2) -> OnMethod where - H: Handler, + H2: Handler, { self.on(MethodFilter::PATCH, handler) } @@ -556,9 +524,9 @@ impl OnMethod { /// Chain an additional handler that will only accept `POST` requests. /// /// See [`OnMethod::get`] for an example. - pub fn post(self, handler: H) -> OnMethod, Self> + pub fn post(self, handler: H2) -> OnMethod where - H: Handler, + H2: Handler, { self.on(MethodFilter::POST, handler) } @@ -566,9 +534,9 @@ impl OnMethod { /// Chain an additional handler that will only accept `PUT` requests. /// /// See [`OnMethod::get`] for an example. - pub fn put(self, handler: H) -> OnMethod, Self> + pub fn put(self, handler: H2) -> OnMethod where - H: Handler, + H2: Handler, { self.on(MethodFilter::PUT, handler) } @@ -576,9 +544,9 @@ impl OnMethod { /// Chain an additional handler that will only accept `TRACE` requests. /// /// See [`OnMethod::get`] for an example. - pub fn trace(self, handler: H) -> OnMethod, Self> + pub fn trace(self, handler: H2) -> OnMethod where - H: Handler, + H2: Handler, { self.on(MethodFilter::TRACE, handler) } @@ -602,30 +570,28 @@ impl OnMethod { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` - pub fn on( - self, - method: MethodFilter, - handler: H, - ) -> OnMethod, Self> + pub fn on(self, method: MethodFilter, handler: H2) -> OnMethod where - H: Handler, + H2: Handler, { OnMethod { method, - svc: handler.into_service(), + handler, fallback: self, + _marker: PhantomData, } } } -impl Service> for OnMethod +impl Service> for OnMethod where - S: Service, Response = Response, Error = Infallible> + Clone, + H: Handler, F: Service, Response = Response, Error = Infallible> + Clone, + B: Send + 'static, { type Response = Response; type Error = Infallible; - type Future = future::OnMethodFuture; + type Future = future::OnMethodFuture; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) @@ -634,16 +600,20 @@ where fn call(&mut self, req: Request) -> Self::Future { let req_method = req.method().clone(); - let f = if self.method.matches(req.method()) { - let fut = self.svc.clone().oneshot(req); - RouteFuture::a(fut) + let fut = if self.method.matches(req.method()) { + let handler = self.handler.clone(); + let fut = Box::pin(async move { + let res = Handler::call(handler, req).await; + Ok::<_, F::Error>(res) + }) as futures_util::future::BoxFuture<'static, _>; + Either::Left(fut) } else { let fut = self.fallback.clone().oneshot(req); - RouteFuture::b(fut) + Either::Right(fut) }; future::OnMethodFuture { - inner: f, + inner: fut, req_method, } } diff --git a/src/service/future.rs b/src/service/future.rs index d1c0b89e..b80aac5a 100644 --- a/src/service/future.rs +++ b/src/service/future.rs @@ -3,6 +3,7 @@ use crate::{ body::{box_body, BoxBody}, response::IntoResponse, + util::{Either, EitherProj}, }; use bytes::Bytes; use futures_util::ready; @@ -14,7 +15,7 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use tower::{BoxError, Service}; +use tower::{util::Oneshot, BoxError, Service}; pin_project! { /// Response future for [`HandleError`](super::HandleError). @@ -52,54 +53,41 @@ where } } -pin_project! { - /// Response future for [`BoxResponseBody`]. - #[derive(Debug)] - pub struct BoxResponseBodyFuture { - #[pin] - pub(super) future: F, - } -} - -impl Future for BoxResponseBodyFuture -where - F: Future, E>>, - B: http_body::Body + Send + Sync + 'static, - B::Error: Into + Send + Sync + 'static, -{ - type Output = Result, E>; - - fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let res = ready!(self.project().future.poll(cx))?; - let res = res.map(box_body); - Poll::Ready(Ok(res)) - } -} - pin_project! { /// The response future for [`OnMethod`](super::OnMethod). - #[derive(Debug)] pub struct OnMethodFuture where S: Service>, F: Service> { #[pin] - pub(super) inner: crate::routing::future::RouteFuture, + pub(super) inner: Either< + Oneshot>, + Oneshot>, + >, + // pub(super) inner: crate::routing::future::RouteFuture, pub(super) req_method: Method, } } -impl Future for OnMethodFuture +impl Future for OnMethodFuture where - S: Service, Response = Response>, + S: Service, Response = Response> + Clone, + ResBody: http_body::Body + Send + Sync + 'static, + ResBody::Error: Into, F: Service, Response = Response, Error = S::Error>, { type Output = Result, S::Error>; + #[allow(warnings)] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); - let response = futures_util::ready!(this.inner.poll(cx))?; + + let response = match this.inner.project() { + EitherProj::A { inner } => ready!(inner.poll(cx))?.map(box_body), + EitherProj::B { inner } => ready!(inner.poll(cx))?, + }; + if this.req_method == &Method::HEAD { let response = response.map(|_| box_body(Empty::new())); Poll::Ready(Ok(response)) diff --git a/src/service/mod.rs b/src/service/mod.rs index ea7bed05..318757cf 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -105,7 +105,7 @@ pub mod future; /// Route requests to the given service regardless of the HTTP method. /// /// See [`get`] for an example. -pub fn any(svc: S) -> OnMethod, EmptyRouter> +pub fn any(svc: S) -> OnMethod, B> where S: Service> + Clone, { @@ -115,7 +115,7 @@ where /// Route `CONNECT` requests to the given service. /// /// See [`get`] for an example. -pub fn connect(svc: S) -> OnMethod, EmptyRouter> +pub fn connect(svc: S) -> OnMethod, B> where S: Service> + Clone, { @@ -125,7 +125,7 @@ where /// Route `DELETE` requests to the given service. /// /// See [`get`] for an example. -pub fn delete(svc: S) -> OnMethod, EmptyRouter> +pub fn delete(svc: S) -> OnMethod, B> where S: Service> + Clone, { @@ -156,7 +156,7 @@ where /// Note that `get` routes will also be called for `HEAD` requests but will have /// the response body removed. Make sure to add explicit `HEAD` routes /// afterwards. -pub fn get(svc: S) -> OnMethod, EmptyRouter> +pub fn get(svc: S) -> OnMethod, B> where S: Service> + Clone, { @@ -166,7 +166,7 @@ where /// Route `HEAD` requests to the given service. /// /// See [`get`] for an example. -pub fn head(svc: S) -> OnMethod, EmptyRouter> +pub fn head(svc: S) -> OnMethod, B> where S: Service> + Clone, { @@ -176,7 +176,7 @@ where /// Route `OPTIONS` requests to the given service. /// /// See [`get`] for an example. -pub fn options(svc: S) -> OnMethod, EmptyRouter> +pub fn options(svc: S) -> OnMethod, B> where S: Service> + Clone, { @@ -186,7 +186,7 @@ where /// Route `PATCH` requests to the given service. /// /// See [`get`] for an example. -pub fn patch(svc: S) -> OnMethod, EmptyRouter> +pub fn patch(svc: S) -> OnMethod, B> where S: Service> + Clone, { @@ -196,7 +196,7 @@ where /// Route `POST` requests to the given service. /// /// See [`get`] for an example. -pub fn post(svc: S) -> OnMethod, EmptyRouter> +pub fn post(svc: S) -> OnMethod, B> where S: Service> + Clone, { @@ -206,7 +206,7 @@ where /// Route `PUT` requests to the given service. /// /// See [`get`] for an example. -pub fn put(svc: S) -> OnMethod, EmptyRouter> +pub fn put(svc: S) -> OnMethod, B> where S: Service> + Clone, { @@ -216,7 +216,7 @@ where /// Route `TRACE` requests to the given service. /// /// See [`get`] for an example. -pub fn trace(svc: S) -> OnMethod, EmptyRouter> +pub fn trace(svc: S) -> OnMethod, B> where S: Service> + Clone, { @@ -243,38 +243,49 @@ where /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` -pub fn on( - method: MethodFilter, - svc: S, -) -> OnMethod, EmptyRouter> +pub fn on(method: MethodFilter, svc: S) -> OnMethod, B> where S: Service> + Clone, { OnMethod { method, - svc: BoxResponseBody { - inner: svc, - _request_body: PhantomData, - }, + svc, fallback: EmptyRouter::method_not_allowed(), + _request_body: PhantomData, } } /// A [`Service`] that accepts requests based on a [`MethodFilter`] and allows /// chaining additional services. -#[derive(Clone, Debug)] -pub struct OnMethod { +#[derive(Debug)] // TODO(david): don't require debug for B +pub struct OnMethod { pub(crate) method: MethodFilter, pub(crate) svc: S, pub(crate) fallback: F, + pub(crate) _request_body: PhantomData B>, } -impl OnMethod { +impl Clone for OnMethod +where + S: Clone, + F: Clone, +{ + fn clone(&self) -> Self { + Self { + method: self.method, + svc: self.svc.clone(), + fallback: self.fallback.clone(), + _request_body: PhantomData, + } + } +} + +impl OnMethod { /// Chain an additional service that will accept all requests regardless of /// its HTTP method. /// /// See [`OnMethod::get`] for an example. - pub fn any(self, svc: T) -> OnMethod, Self> + pub fn any(self, svc: T) -> OnMethod where T: Service> + Clone, { @@ -284,7 +295,7 @@ impl OnMethod { /// Chain an additional service that will only accept `CONNECT` requests. /// /// See [`OnMethod::get`] for an example. - pub fn connect(self, svc: T) -> OnMethod, Self> + pub fn connect(self, svc: T) -> OnMethod where T: Service> + Clone, { @@ -294,7 +305,7 @@ impl OnMethod { /// Chain an additional service that will only accept `DELETE` requests. /// /// See [`OnMethod::get`] for an example. - pub fn delete(self, svc: T) -> OnMethod, Self> + pub fn delete(self, svc: T) -> OnMethod where T: Service> + Clone, { @@ -330,7 +341,7 @@ impl OnMethod { /// Note that `get` routes will also be called for `HEAD` requests but will have /// the response body removed. Make sure to add explicit `HEAD` routes /// afterwards. - pub fn get(self, svc: T) -> OnMethod, Self> + pub fn get(self, svc: T) -> OnMethod where T: Service> + Clone, { @@ -340,7 +351,7 @@ impl OnMethod { /// Chain an additional service that will only accept `HEAD` requests. /// /// See [`OnMethod::get`] for an example. - pub fn head(self, svc: T) -> OnMethod, Self> + pub fn head(self, svc: T) -> OnMethod where T: Service> + Clone, { @@ -350,7 +361,7 @@ impl OnMethod { /// Chain an additional service that will only accept `OPTIONS` requests. /// /// See [`OnMethod::get`] for an example. - pub fn options(self, svc: T) -> OnMethod, Self> + pub fn options(self, svc: T) -> OnMethod where T: Service> + Clone, { @@ -360,7 +371,7 @@ impl OnMethod { /// Chain an additional service that will only accept `PATCH` requests. /// /// See [`OnMethod::get`] for an example. - pub fn patch(self, svc: T) -> OnMethod, Self> + pub fn patch(self, svc: T) -> OnMethod where T: Service> + Clone, { @@ -370,7 +381,7 @@ impl OnMethod { /// Chain an additional service that will only accept `POST` requests. /// /// See [`OnMethod::get`] for an example. - pub fn post(self, svc: T) -> OnMethod, Self> + pub fn post(self, svc: T) -> OnMethod where T: Service> + Clone, { @@ -380,7 +391,7 @@ impl OnMethod { /// Chain an additional service that will only accept `PUT` requests. /// /// See [`OnMethod::get`] for an example. - pub fn put(self, svc: T) -> OnMethod, Self> + pub fn put(self, svc: T) -> OnMethod where T: Service> + Clone, { @@ -390,7 +401,7 @@ impl OnMethod { /// Chain an additional service that will only accept `TRACE` requests. /// /// See [`OnMethod::get`] for an example. - pub fn trace(self, svc: T) -> OnMethod, Self> + pub fn trace(self, svc: T) -> OnMethod where T: Service> + Clone, { @@ -422,17 +433,15 @@ impl OnMethod { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` - pub fn on(self, method: MethodFilter, svc: T) -> OnMethod, Self> + pub fn on(self, method: MethodFilter, svc: T) -> OnMethod where T: Service> + Clone, { OnMethod { method, - svc: BoxResponseBody { - inner: svc, - _request_body: PhantomData, - }, + svc, fallback: self, + _request_body: PhantomData, } } @@ -459,9 +468,11 @@ impl OnMethod { // this is identical to `routing::OnMethod`'s implementation. Would be nice to find a way to clean // that up, but not sure its possible. -impl Service> for OnMethod +impl Service> for OnMethod where - S: Service, Response = Response> + Clone, + S: Service, Response = Response> + Clone, + ResBody: http_body::Body + Send + Sync + 'static, + ResBody::Error: Into, F: Service, Response = Response, Error = S::Error> + Clone, { type Response = Response; @@ -473,14 +484,16 @@ where } fn call(&mut self, req: Request) -> Self::Future { + use crate::util::Either; + let req_method = req.method().clone(); let f = if self.method.matches(req.method()) { let fut = self.svc.clone().oneshot(req); - RouteFuture::a(fut) + Either::A { inner: fut } } else { let fut = self.fallback.clone().oneshot(req); - RouteFuture::b(fut) + Either::B { inner: fut } }; future::OnMethodFuture { @@ -574,55 +587,6 @@ where } } -/// A [`Service`] that boxes response bodies. -pub struct BoxResponseBody { - inner: S, - _request_body: PhantomData B>, -} - -impl Clone for BoxResponseBody -where - S: Clone, -{ - fn clone(&self) -> Self { - Self { - inner: self.inner.clone(), - _request_body: PhantomData, - } - } -} - -impl fmt::Debug for BoxResponseBody -where - S: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("BoxResponseBody") - .field("inner", &self.inner) - .finish() - } -} - -impl Service> for BoxResponseBody -where - S: Service, Response = Response> + Clone, - ResBody: http_body::Body + Send + Sync + 'static, - ResBody::Error: Into + Send + Sync + 'static, -{ - type Response = Response; - type Error = S::Error; - type Future = future::BoxResponseBodyFuture>>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: Request) -> Self::Future { - let fut = self.inner.clone().oneshot(req); - future::BoxResponseBodyFuture { future: fut } - } -} - /// ```compile_fail /// use crate::{service::ServiceExt, prelude::*}; /// use tower::service_fn; diff --git a/src/tests/get_to_head.rs b/src/tests/get_to_head.rs index 55c49bb0..3abc081b 100644 --- a/src/tests/get_to_head.rs +++ b/src/tests/get_to_head.rs @@ -39,17 +39,19 @@ mod for_handlers { mod for_services { use super::*; use crate::service::get; + use headers::HeaderValue; #[tokio::test] async fn get_handles_head() { let app = route( "/", - get((|| async { - let mut headers = HeaderMap::new(); - headers.insert("x-some-header", "foobar".parse().unwrap()); - (headers, "you shouldn't see this") - }) - .into_service()), + get(service_fn(|req: Request| async move { + let res = Response::builder() + .header("x-some-header", "foobar".parse::().unwrap()) + .body(Body::from("you shouldn't see this")) + .unwrap(); + Ok::<_, Infallible>(res) + })), ); // don't use reqwest because it always strips bodies from HEAD responses diff --git a/src/tests/mod.rs b/src/tests/mod.rs index caa4aaad..11228055 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -354,10 +354,7 @@ async fn routing_between_services() { }), ), ) - .route( - "/two", - service::on(MethodFilter::GET, handle.into_service()), - ); + .route("/two", service::on(MethodFilter::GET, any(handle))); let addr = run_in_background(app).await; diff --git a/src/util.rs b/src/util.rs index 1c05533c..a731f741 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,4 +1,5 @@ use bytes::Bytes; +use pin_project_lite::pin_project; use std::ops::Deref; /// A string like type backed by `Bytes` making it cheap to clone. @@ -28,3 +29,11 @@ impl ByteStr { std::str::from_utf8(&self.0).unwrap() } } + +pin_project! { + #[project = EitherProj] + pub(crate) enum Either { + A { #[pin] inner: A }, + B { #[pin] inner: B }, + } +}