Fix `ServiceExt::handle_error` footgun (#120)

As described in
https://github.com/tokio-rs/axum/pull/108#issuecomment-892811637, a
`HandleError` created from `axum::ServiceExt::handle_error` should _not_
implement `RoutingDsl` as that leads to confusing routing behavior.

The technique used here of adding another type parameter to
`HandleError` isn't very clean, I think. But the alternative is
duplicating `HandleError` and having two versions, which I think is less
desirable.
This commit is contained in:
David Pedersen 2021-08-07 16:44:12 +02:00 committed by GitHub
parent b5b9db47db
commit 95d7582d28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 46 additions and 12 deletions

View File

@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Breaking changes
- Ensure a `HandleError` service created from `axum::ServiceExt::handle_error`
_does not_ implement `RoutingDsl` as that could lead to confusing routing
behavior. ([#120](https://github.com/tokio-rs/axum/pull/120))
- Remove `QueryStringMissing` as it was no longer being used
- `extract::extractor_middleware::ExtractorMiddlewareResponseFuture` moved
to `extract::extractor_middleware::future::ResponseFuture` ([#133](https://github.com/tokio-rs/axum/pull/133))

View File

@ -5,7 +5,7 @@ use crate::{
extract::FromRequest,
response::IntoResponse,
routing::{future::RouteFuture, EmptyRouter, MethodFilter},
service::HandleError,
service::{HandleError, HandleErrorFromRouter},
};
use async_trait::async_trait;
use bytes::Bytes;
@ -371,7 +371,7 @@ impl<S, T> Layered<S, T> {
pub fn handle_error<F, ReqBody, ResBody, Res, E>(
self,
f: F,
) -> Layered<HandleError<S, F, ReqBody>, T>
) -> Layered<HandleError<S, F, ReqBody, HandleErrorFromRouter>, T>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>>,
F: FnOnce(S::Error) -> Result<Res, E>,

View File

@ -6,6 +6,7 @@ use crate::{
buffer::MpscBuffer,
extract::connect_info::{Connected, IntoMakeServiceWithConnectInfo},
response::IntoResponse,
service::HandleErrorFromRouter,
util::ByteStr,
};
use async_trait::async_trait;
@ -716,7 +717,7 @@ impl<S> Layered<S> {
pub fn handle_error<F, ReqBody, ResBody, Res, E>(
self,
f: F,
) -> crate::service::HandleError<S, F, ReqBody>
) -> crate::service::HandleError<S, F, ReqBody, HandleErrorFromRouter>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
F: FnOnce(S::Error) -> Result<Res, E>,

View File

@ -462,13 +462,13 @@ where
/// [`handler::Layered::handle_error`](crate::handler::Layered::handle_error) or
/// [`routing::Layered::handle_error`](crate::routing::Layered::handle_error).
/// See those methods for more details.
pub struct HandleError<S, F, B> {
pub struct HandleError<S, F, B, T> {
inner: S,
f: F,
_marker: PhantomData<fn() -> B>,
_marker: PhantomData<fn() -> (B, T)>,
}
impl<S, F, B> Clone for HandleError<S, F, B>
impl<S, F, B, T> Clone for HandleError<S, F, B, T>
where
S: Clone,
F: Clone,
@ -478,11 +478,23 @@ where
}
}
impl<S, F, B> crate::routing::RoutingDsl for HandleError<S, F, B> {}
/// Maker type used for [`HandleError`] to indicate that it should implement
/// [`RoutingDsl`](crate::routing::RoutingDsl).
#[non_exhaustive]
#[derive(Debug)]
pub struct HandleErrorFromRouter;
impl<S, F, B> crate::sealed::Sealed for HandleError<S, F, B> {}
/// Maker type used for [`HandleError`] to indicate that it should _not_ implement
/// [`RoutingDsl`](crate::routing::RoutingDsl).
#[non_exhaustive]
#[derive(Debug)]
pub struct HandleErrorFromService;
impl<S, F, B> HandleError<S, F, B> {
impl<S, F, B> crate::routing::RoutingDsl for HandleError<S, F, B, HandleErrorFromRouter> {}
impl<S, F, B> crate::sealed::Sealed for HandleError<S, F, B, HandleErrorFromRouter> {}
impl<S, F, B, T> HandleError<S, F, B, T> {
pub(crate) fn new(inner: S, f: F) -> Self {
Self {
inner,
@ -492,7 +504,7 @@ impl<S, F, B> HandleError<S, F, B> {
}
}
impl<S, F, B> fmt::Debug for HandleError<S, F, B>
impl<S, F, B, T> fmt::Debug for HandleError<S, F, B, T>
where
S: fmt::Debug,
{
@ -504,7 +516,7 @@ where
}
}
impl<S, F, ReqBody, ResBody, Res, E> Service<Request<ReqBody>> for HandleError<S, F, ReqBody>
impl<S, F, ReqBody, ResBody, Res, E, T> Service<Request<ReqBody>> for HandleError<S, F, ReqBody, T>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
F: FnOnce(S::Error) -> Result<Res, E> + Clone,
@ -570,7 +582,7 @@ pub trait ServiceExt<ReqBody, ResBody>:
/// It works similarly to [`routing::Layered::handle_error`]. See that for more details.
///
/// [`routing::Layered::handle_error`]: crate::routing::Layered::handle_error
fn handle_error<F, Res, E>(self, f: F) -> HandleError<Self, F, ReqBody>
fn handle_error<F, Res, E>(self, f: F) -> HandleError<Self, F, ReqBody, HandleErrorFromService>
where
Self: Sized,
F: FnOnce(Self::Error) -> Result<Res, E>,
@ -645,3 +657,21 @@ where
future::BoxResponseBodyFuture { future: fut }
}
}
/// ```compile_fail
/// use crate::{service::ServiceExt, prelude::*};
/// use tower::service_fn;
/// use hyper::Body;
/// use http::{Request, Response, StatusCode};
///
/// let svc = service_fn(|_: Request<Body>| async {
/// Ok::<_, hyper::Error>(Response::new(Body::empty()))
/// })
/// .handle_error::<_, _, hyper::Error>(|_| Ok(StatusCode::INTERNAL_SERVER_ERROR));
///
/// // `.route` should not compile, ie `HandleError` created from any
/// // random service should not implement `RoutingDsl`
/// svc.route::<_, Body>("/", get(|| async {}));
/// ```
#[allow(dead_code)]
fn compile_fail_tests() {}