diff --git a/examples/global-404-handler/src/main.rs b/examples/global-404-handler/src/main.rs index ca7e34d1..b51a5162 100644 --- a/examples/global-404-handler/src/main.rs +++ b/examples/global-404-handler/src/main.rs @@ -24,8 +24,8 @@ async fn main() { // build our application with a route let app = Router::new().route("/", get(handler)); - // make sure this is added as the very last thing - let app = app.or(handler_404.into_service()); + // add a fallback service for handling routes to unknown paths + let app = app.fallback(handler_404.into_service()); // run it let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); diff --git a/src/lib.rs b/src/lib.rs index 0774ba17..f21b681e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,10 +7,11 @@ //! - [Handlers](#handlers) //! - [Debugging handler type errors](#debugging-handler-type-errors) //! - [Routing](#routing) -//! - [Routing to any `Service`](#routing-to-any-service) -//! - [Routing to fallible services](#routing-to-fallible-services) //! - [Wildcard routes](#wildcard-routes) //! - [Nesting routes](#nesting-routes) +//! - [Fallback routes](#fallback-routes) +//! - [Routing to any `Service`](#routing-to-any-service) +//! - [Routing to fallible services](#routing-to-fallible-services) //! - [Extractors](#extractors) //! - [Common extractors](#common-extractors) //! - [Applying multiple extractors](#applying-multiple-extractors) @@ -143,7 +144,7 @@ //! //! # Routing //! -//! Routing between handlers looks like this: +//! [`Router::route`] is the main way to add routes: //! //! ```rust,no_run //! use axum::{ @@ -174,11 +175,125 @@ //! Routes can also be dynamic like `/users/:id`. See [extractors](#extractors) //! for more details. //! -//! You can also define routes separately and merge them with [`Router::or`]. +//! You can also define routes separately and merge them with [`Router::merge`]. //! //! Routes are not allowed to overlap and will panic if an overlapping route is //! added. This also means the order in which routes are added doesn't matter. //! +//! ## Wildcard routes +//! +//! axum also supports wildcard routes: +//! +//! ```rust,no_run +//! use axum::{ +//! routing::get, +//! Router, +//! }; +//! +//! let app = Router::new() +//! // this matches any request that starts with `/api` +//! .route("/api/*rest", get(|| async { /* ... */ })); +//! # async { +//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +//! # }; +//! ``` +//! +//! The matched path can be extracted via [`extract::Path`]: +//! +//! ```rust,no_run +//! use axum::{ +//! routing::get, +//! extract::Path, +//! Router, +//! }; +//! +//! let app = Router::new().route("/api/*rest", get(|Path(rest): Path| async { +//! // `rest` will be everything after `/api` +//! })); +//! # async { +//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +//! # }; +//! ``` +//! +//! ## Nesting routes +//! +//! Routes can be nested by calling [`Router::nest`](routing::Router::nest): +//! +//! ```rust,no_run +//! use axum::{ +//! body::{Body, BoxBody}, +//! http::Request, +//! routing::get, +//! Router, +//! }; +//! use tower_http::services::ServeFile; +//! use http::Response; +//! +//! fn api_routes() -> Router { +//! Router::new() +//! .route("/users", get(|_: Request| async { /* ... */ })) +//! } +//! +//! let app = Router::new() +//! .route("/", get(|_: Request| async { /* ... */ })) +//! .nest("/api", api_routes()); +//! # async { +//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +//! # }; +//! ``` +//! +//! Note that nested routes will not see the orignal request URI but instead +//! have the matched prefix stripped. This is necessary for services like static +//! file serving to work. Use [`OriginalUri`] if you need the original request +//! URI. +//! +//! Nested routes are similar to wild card routes. The difference is that +//! wildcard routes still see the whole URI whereas nested routes will have +//! the prefix stripped. +//! +//! ```rust +//! use axum::{routing::get, http::Uri, Router}; +//! +//! let app = Router::new() +//! .route("/foo/*rest", get(|uri: Uri| async { +//! // `uri` will contain `/foo` +//! })) +//! .nest("/bar", get(|uri: Uri| async { +//! // `uri` will _not_ contain `/bar` +//! })); +//! # async { +//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +//! # }; +//! ``` +//! +//! ## Fallback routes +//! +//! By default axum will respond with an empty `404 Not Found` response to unhandled requests. To +//! override that you can use [`Router::fallback`]: +//! +//! ```rust +//! use axum::{ +//! Router, +//! routing::get, +//! handler::Handler, +//! response::IntoResponse, +//! http::{StatusCode, Uri}, +//! }; +//! +//! async fn fallback(uri: Uri) -> impl IntoResponse { +//! (StatusCode::NOT_FOUND, format!("No route for {}", uri)) +//! } +//! +//! let app = Router::new() +//! .route("/foo", get(|| async { /* ... */ })) +//! .fallback(fallback.into_service()); +//! # async { +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +//! # }; +//! ``` +//! +//! See [`Router::fallback`] for more details. +//! //! ## Routing to any [`Service`] //! //! axum also supports routing to general [`Service`]s: @@ -314,92 +429,6 @@ //! See ["Error handling"](#error-handling) for more details on [`handle_error`] //! and error handling in general. //! -//! ## Wildcard routes -//! -//! axum also supports wildcard routes: -//! -//! ```rust,no_run -//! use axum::{ -//! routing::get, -//! Router, -//! }; -//! -//! let app = Router::new() -//! // this matches any request that starts with `/api` -//! .route("/api/*rest", get(|| async { /* ... */ })); -//! # async { -//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -//! # }; -//! ``` -//! -//! The matched path can be extracted via [`extract::Path`]: -//! -//! ```rust,no_run -//! use axum::{ -//! routing::get, -//! extract::Path, -//! Router, -//! }; -//! -//! let app = Router::new().route("/api/*rest", get(|Path(rest): Path| async { -//! // `rest` will be everything after `/api` -//! })); -//! # async { -//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -//! # }; -//! ``` -//! -//! ## Nesting routes -//! -//! Routes can be nested by calling [`Router::nest`](routing::Router::nest): -//! -//! ```rust,no_run -//! use axum::{ -//! body::{Body, BoxBody}, -//! http::Request, -//! routing::get, -//! Router, -//! }; -//! use tower_http::services::ServeFile; -//! use http::Response; -//! -//! fn api_routes() -> Router { -//! Router::new() -//! .route("/users", get(|_: Request| async { /* ... */ })) -//! } -//! -//! let app = Router::new() -//! .route("/", get(|_: Request| async { /* ... */ })) -//! .nest("/api", api_routes()); -//! # async { -//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -//! # }; -//! ``` -//! -//! Note that nested routes will not see the orignal request URI but instead -//! have the matched prefix stripped. This is necessary for services like static -//! file serving to work. Use [`OriginalUri`] if you need the original request -//! URI. -//! -//! Nested routes are similar to wild card routes. The difference is that -//! wildcard routes still see the whole URI whereas nested routes will have -//! the prefix stripped. -//! -//! ```rust -//! use axum::{routing::get, http::Uri, Router}; -//! -//! let app = Router::new() -//! .route("/foo/*rest", get(|uri: Uri| async { -//! // `uri` will contain `/foo` -//! })) -//! .nest("/bar", get(|uri: Uri| async { -//! // `uri` will _not_ contain `/bar` -//! })); -//! # async { -//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -//! # }; -//! ``` -//! //! # Extractors //! //! An extractor is a type that implements [`FromRequest`]. Extractors is how @@ -862,7 +891,7 @@ //! Note that [`Router::layer`] applies the middleware to all previously added //! routes, of that particular `Router`. If you need multiple groups of routes //! with different middleware build them separately and combine them with -//! [`Router::or`]: +//! [`Router::merge`]: //! //! ```rust,no_run //! use axum::{ @@ -883,7 +912,7 @@ //! .route("/requires-auth", get(handler)) //! .layer(MyAuthLayer::new()); //! -//! let app = foo.or(bar); +//! let app = foo.merge(bar); //! # async { //! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; @@ -1148,7 +1177,7 @@ //! [`IntoResponse`]: crate::response::IntoResponse //! [`Timeout`]: tower::timeout::Timeout //! [examples]: https://github.com/tokio-rs/axum/tree/main/examples -//! [`Router::or`]: crate::routing::Router::or +//! [`Router::merge`]: crate::routing::Router::merge //! [`axum::Server`]: hyper::server::Server //! [`OriginalUri`]: crate::extract::OriginalUri //! [`Service`]: tower::Service diff --git a/src/routing/future.rs b/src/routing/future.rs index 4c755cb6..6b01c3aa 100644 --- a/src/routing/future.rs +++ b/src/routing/future.rs @@ -1,9 +1,6 @@ //! Future types. -use crate::{ - body::BoxBody, - routing::{FromEmptyRouter, UriStack}, -}; +use crate::body::BoxBody; use http::{Request, Response}; use pin_project_lite::pin_project; use std::{ @@ -15,120 +12,49 @@ use std::{ use tower::util::Oneshot; use tower_service::Service; +opaque_future! { + /// Response future for [`Router`](super::Router). + pub type RouterFuture = + futures_util::future::Either< + Oneshot, Request>, + std::future::Ready, Infallible>>, + >; +} + +opaque_future! { + /// Response future for [`Route`](super::Route). + pub type RouteFuture = + futures_util::future::BoxFuture<'static, Result, Infallible>>; +} + opaque_future! { /// Response future for [`EmptyRouter`](super::EmptyRouter). pub type EmptyRouterFuture = std::future::Ready, E>>; } -opaque_future! { - /// Response future for [`Routes`](super::Routes). - pub type RoutesFuture = - futures_util::future::BoxFuture<'static, Result, Infallible>>; -} - -pin_project! { - /// The response future for [`Route`](super::Route). - #[derive(Debug)] - pub(crate) struct RouteFuture - where - S: Service>, - F: Service> - { - #[pin] - state: RouteFutureInner, - } -} - -impl RouteFuture -where - S: Service>, - F: Service>, -{ - pub(crate) fn a(a: Oneshot>) -> Self { - RouteFuture { - state: RouteFutureInner::A { a }, - } - } - - pub(crate) fn b(b: Oneshot>) -> Self { - RouteFuture { - state: RouteFutureInner::B { b }, - } - } -} - -pin_project! { - #[project = RouteFutureInnerProj] - #[derive(Debug)] - enum RouteFutureInner - where - S: Service>, - F: Service>, - { - A { - #[pin] - a: Oneshot>, - }, - B { - #[pin] - b: Oneshot> - }, - } -} - -impl Future for RouteFuture -where - S: Service, Response = Response, Error = Infallible>, - F: Service, Response = Response, Error = Infallible>, - B: Send + Sync + 'static, -{ - type Output = Result, Infallible>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.project().state.project() { - RouteFutureInnerProj::A { a } => a.poll(cx), - RouteFutureInnerProj::B { b } => b.poll(cx), - } - } -} - pin_project! { /// The response future for [`Nested`](super::Nested). #[derive(Debug)] - pub(crate) struct NestedFuture + pub(crate) struct NestedFuture where S: Service>, - F: Service> { #[pin] - pub(super) inner: RouteFuture, + pub(super) inner: Oneshot> } } -impl Future for NestedFuture +impl Future for NestedFuture where S: Service, Response = Response, Error = Infallible>, - F: Service, Response = Response, Error = Infallible>, B: Send + Sync + 'static, { type Output = Result, Infallible>; + #[inline] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut res: Response<_> = futures_util::ready!(self.project().inner.poll(cx)?); - - // `nest` mutates the URI of the request so if it turns out no route matched - // we need to reset the URI so the next routes see the original URI - // - // That requires using a stack since we can have arbitrarily nested routes - if let Some(from_empty_router) = res.extensions_mut().get_mut::>() { - let uri = UriStack::pop(&mut from_empty_router.request); - if let Some(uri) = uri { - *from_empty_router.request.uri_mut() = uri; - } - } - - Poll::Ready(Ok(res)) + self.project().inner.poll(cx) } } diff --git a/src/routing/mod.rs b/src/routing/mod.rs index 59bbef70..56970452 100644 --- a/src/routing/mod.rs +++ b/src/routing/mod.rs @@ -1,6 +1,6 @@ //! Routing between [`Service`]s and handlers. -use self::future::{EmptyRouterFuture, NestedFuture, RouteFuture, RoutesFuture}; +use self::future::{EmptyRouterFuture, NestedFuture, RouteFuture, RouterFuture}; use crate::{ body::{box_body, Body, BoxBody}, clone_box_service::CloneBoxService, @@ -13,17 +13,17 @@ use crate::{ }; use bytes::Bytes; use http::{Request, Response, StatusCode, Uri}; -use matchit::Node; use std::{ borrow::Cow, + collections::HashMap, convert::Infallible, fmt, future::ready, marker::PhantomData, task::{Context, Poll}, }; -use tower::util::ServiceExt; -use tower_http::map_response_body::MapResponseBody; +use tower::{util::ServiceExt, ServiceBuilder}; +use tower_http::map_response_body::MapResponseBodyLayer; use tower_layer::Layer; use tower_service::Service; @@ -32,7 +32,6 @@ pub mod handler_method_router; pub mod service_method_router; mod method_filter; -mod or; pub use self::method_filter::MethodFilter; @@ -41,7 +40,7 @@ pub use self::handler_method_router::{ any, connect, delete, get, head, on, options, patch, post, put, trace, MethodRouter, }; -#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] struct RouteId(u64); impl RouteId { @@ -54,8 +53,9 @@ impl RouteId { /// The router type for composing handlers and services. pub struct Router { - routes: Routes, - node: Node, + routes: HashMap>, + node: Node, + fallback: Option>, } impl Clone for Router { @@ -63,6 +63,7 @@ impl Clone for Router { Self { routes: self.routes.clone(), node: self.node.clone(), + fallback: self.fallback.clone(), } } } @@ -80,11 +81,13 @@ impl fmt::Debug for Router { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Router") .field("routes", &self.routes) + .field("node", &self.node) + .field("fallback", &self.fallback) .finish() } } -const NEST_TAIL_PARAM: &str = "__axum_nest"; +const NEST_TAIL_PARAM: &str = "__axum_internal_nest_capture"; impl Router where @@ -96,19 +99,23 @@ where /// all requests. pub fn new() -> Self { Self { - routes: Routes(CloneBoxService::new(EmptyRouter::not_found())), - node: Node::new(), + routes: Default::default(), + node: Default::default(), + fallback: None, } } /// Add another route to the router. /// /// `path` is a string of path segments separated by `/`. Each segment - /// can be either concrete or a capture: + /// can be either concrete, a capture, or a wildcard: /// /// - `/foo/bar/baz` will only match requests where the path is `/foo/bar/bar`. /// - `/:foo` will match any route with exactly one segment _and_ it will /// capture the first segment and store it at the key `foo`. + /// - `/foo/bar/*rest` will match all requests that start with `/foo/bar` + /// and any number of segments after that. It will also create a capture + /// with the key `rest` that contains the matched segments. /// /// `service` is the [`Service`] that should receive the request if the path /// matches `path`. @@ -116,13 +123,14 @@ where /// # Example /// /// ```rust - /// use axum::{routing::{get, delete}, Router}; + /// use axum::{Router, routing::{get, delete}, extract::Path}; /// /// let app = Router::new() /// .route("/", get(root)) /// .route("/users", get(list_users).post(create_user)) /// .route("/users/:id", get(show_user)) - /// .route("/api/:version/users/:id/action", delete(do_thing)); + /// .route("/api/:version/users/:id/action", delete(do_users_action)) + /// .route("/assets/*path", get(serve_asset)); /// /// async fn root() { /* ... */ } /// @@ -132,7 +140,9 @@ where /// /// async fn show_user() { /* ... */ } /// - /// async fn do_thing() { /* ... */ } + /// async fn do_users_action() { /* ... */ } + /// + /// async fn serve_asset(Path(path): Path) { /* ... */ } /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; @@ -194,14 +204,9 @@ where panic!("Invalid route: {}", err); } - Router { - routes: Routes(CloneBoxService::new(Route { - id, - svc, - fallback: self.routes, - })), - node: self.node, - } + self.routes.insert(id, Route(CloneBoxService::new(svc))); + + self } /// Nest a group of routes (or a [`Service`]) at some path. @@ -294,7 +299,7 @@ where /// # }; /// ``` /// - /// # Wildcard routes + /// # Differences to wildcard routes /// /// Nested routes are similar to wildcard routes. The difference is that /// wildcard routes still see the whole URI whereas nested routes will have @@ -345,14 +350,10 @@ where panic!("Invalid route: {}", err); } - Router { - routes: Routes(CloneBoxService::new(Nested { - id, - svc, - fallback: self.routes, - })), - node: self.node, - } + self.routes + .insert(id, Route(CloneBoxService::new(Nested { svc }))); + + self } /// Apply a [`tower::Layer`] to the router. @@ -424,7 +425,7 @@ where /// ``` pub fn layer(self, layer: L) -> Router where - L: Layer>, + L: Layer>, L::Service: Service< Request, Response = Response, @@ -436,7 +437,28 @@ where LayeredResBody: http_body::Body + Send + Sync + 'static, LayeredResBody::Error: Into, { - self.map(|svc| MapResponseBody::new(layer.layer(svc), box_body)) + let layer = ServiceBuilder::new() + .layer_fn(Route) + .layer_fn(CloneBoxService::new) + .layer(MapResponseBodyLayer::new(box_body)) + .layer(layer); + + let routes = self + .routes + .into_iter() + .map(|(id, route)| { + let route = Layer::layer(&layer, route); + (id, route) + }) + .collect::>>(); + + let fallback = self.fallback.map(|fallback| Layer::layer(&layer, fallback)); + + Router { + routes, + node: self.node, + fallback, + } } /// Convert this router into a [`MakeService`], that is a [`Service`] who's @@ -578,12 +600,127 @@ where /// let team_routes = Router::new().route("/teams", get(teams_list)); /// /// // combine them into one - /// let app = user_routes.or(team_routes); + /// let app = user_routes.merge(team_routes); /// # async { /// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` - pub fn or(self, other: T) -> Self + pub fn merge(mut self, other: Router) -> Self { + let Router { + routes, + node, + fallback, + } = other; + + if let Err(err) = self.node.merge(node) { + panic!("Invalid route: {}", err); + } + + for (id, route) in routes { + assert!(self.routes.insert(id, route).is_none()); + } + + if let Some(new_fallback) = fallback { + self.fallback = Some(new_fallback); + } + + self + } + + /// Add a fallback service to the router. + /// + /// This service will be called if no routes matches the incoming request. + /// + /// ```rust + /// use axum::{ + /// Router, + /// routing::get, + /// handler::Handler, + /// response::IntoResponse, + /// http::{StatusCode, Uri}, + /// }; + /// + /// let app = Router::new() + /// .route("/foo", get(|| async { /* ... */ })) + /// .fallback(fallback.into_service()); + /// + /// async fn fallback(uri: Uri) -> impl IntoResponse { + /// (StatusCode::NOT_FOUND, format!("No route for {}", uri)) + /// } + /// # async { + /// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); + /// # }; + /// ``` + /// + /// Fallbacks only apply to routes that aren't matched by anything in the + /// router. If a handler is matched by a request but returns 404 the + /// fallback is not called. + /// + /// ## When used with `Router::merge` + /// + /// If a router with a fallback is merged with another router that also has + /// a fallback the fallback of the second router will be used: + /// + /// ```rust + /// use axum::{ + /// Router, + /// routing::get, + /// handler::Handler, + /// response::IntoResponse, + /// http::{StatusCode, Uri}, + /// }; + /// + /// let one = Router::new() + /// .route("/one", get(|| async { /* ... */ })) + /// .fallback(fallback_one.into_service()); + /// + /// let two = Router::new() + /// .route("/two", get(|| async { /* ... */ })) + /// .fallback(fallback_two.into_service()); + /// + /// let app = one.merge(two); + /// + /// async fn fallback_one() -> impl IntoResponse { /* ... */ } + /// async fn fallback_two() -> impl IntoResponse { /* ... */ } + /// + /// // the fallback for `app` is `fallback_two` + /// # async { + /// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); + /// # }; + /// ``` + /// + /// If only one of the routers have a fallback that will be used in the + /// merged router. + /// + /// ## When used with `Router::nest` + /// + /// If a router with a fallback is nested inside another router the fallback + /// will only apply to requests that matches the prefix: + /// + /// ```rust + /// use axum::{ + /// Router, + /// routing::get, + /// handler::Handler, + /// response::IntoResponse, + /// http::{StatusCode, Uri}, + /// }; + /// + /// let api = Router::new() + /// .route("/", get(|| async { /* ... */ })) + /// .fallback(api_fallback.into_service()); + /// + /// let app = Router::new().nest("/api", api); + /// + /// async fn api_fallback() -> impl IntoResponse { /* ... */ } + /// + /// // `api_fallback` will be called for `/api/some-unknown-path` but not for + /// // `/some-unknown-path` as the path doesn't start with `/api` + /// # async { + /// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); + /// # }; + /// ``` + pub fn fallback(mut self, svc: T) -> Self where T: Service, Response = Response, Error = Infallible> + Clone @@ -591,24 +728,38 @@ where + 'static, T::Future: Send + 'static, { - self.map(|first| or::Or { - first, - second: other, - }) + self.fallback = Some(Route(CloneBoxService::new(svc))); + self } - fn map(self, f: F) -> Router - where - F: FnOnce(Routes) -> T, - T: Service, Response = Response, Error = Infallible> - + Clone - + Send - + 'static, - T::Future: Send + 'static, - { - Router { - routes: Routes(CloneBoxService::new(f(self.routes))), - node: self.node, + #[inline] + fn call_route(&self, match_: matchit::Match<&RouteId>, mut req: Request) -> RouterFuture { + let id = *match_.value; + req.extensions_mut().insert(id); + + let params = match_ + .params + .iter() + .filter(|(key, _)| !key.starts_with(NEST_TAIL_PARAM)) + .map(|(key, value)| (key.to_string(), value.to_string())) + .collect::>(); + + if let Some(tail) = match_.params.get(NEST_TAIL_PARAM) { + UriStack::push(&mut req); + let new_uri = with_path(req.uri(), tail); + *req.uri_mut() = new_uri; + } + + insert_url_params(&mut req, params); + + let route = self + .routes + .get(&id) + .expect("no route for id. This is a bug in axum. Please file an issue") + .clone(); + + RouterFuture { + future: futures_util::future::Either::Left(route.oneshot(req)), } } } @@ -619,11 +770,11 @@ where { type Response = Response; type Error = Infallible; - type Future = RoutesFuture; + type Future = RouterFuture; #[inline] - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.routes.poll_ready(cx) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } #[inline] @@ -634,27 +785,19 @@ where } let path = req.uri().path().to_string(); + if let Ok(match_) = self.node.at(&path) { - let id = *match_.value; - req.extensions_mut().insert(id); - - let params = match_ - .params - .iter() - .filter(|(key, _)| !key.starts_with(NEST_TAIL_PARAM)) - .map(|(key, value)| (key.to_string(), value.to_string())) - .collect::>(); - - if let Some(tail) = match_.params.get(NEST_TAIL_PARAM) { - UriStack::push(&mut req); - let new_uri = with_path(req.uri(), tail); - *req.uri_mut() = new_uri; + self.call_route(match_, req) + } else if let Some(fallback) = &self.fallback { + RouterFuture { + future: futures_util::future::Either::Left(fallback.clone().oneshot(req)), + } + } else { + let res = EmptyRouter::::not_found().call_sync(req); + RouterFuture { + future: futures_util::future::Either::Right(std::future::ready(Ok(res))), } - - insert_url_params(&mut req, params); } - - self.routes.call(req) } } @@ -670,12 +813,6 @@ impl UriStack { req.extensions_mut().insert(Self(vec![uri])); } } - - pub(crate) fn pop(req: &mut Request) -> Option { - req.extensions_mut() - .get_mut::() - .and_then(|stack| stack.0.pop()) - } } // we store the potential error here such that users can handle invalid path @@ -745,6 +882,15 @@ impl EmptyRouter { _marker: PhantomData, } } + + fn call_sync(&mut self, _req: Request) -> Response + where + B: Send + Sync + 'static, + { + let mut res = Response::new(crate::body::empty()); + *res.status_mut() = self.status; + res + } } impl Clone for EmptyRouter { @@ -774,99 +920,31 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, mut request: Request) -> Self::Future { - if self.status == StatusCode::METHOD_NOT_ALLOWED { - // we're inside a route but there was no method that matched - // so record that so we can override the status if no other - // routes match - request.extensions_mut().insert(NoMethodMatch); - } + fn call(&mut self, request: Request) -> Self::Future { + let res = self.call_sync(request); - if self.status == StatusCode::NOT_FOUND - && request.extensions().get::().is_some() - { - self.status = StatusCode::METHOD_NOT_ALLOWED; - } - - let mut res = Response::new(crate::body::empty()); - - res.extensions_mut().insert(FromEmptyRouter { request }); - - *res.status_mut() = self.status; EmptyRouterFuture { future: ready(Ok(res)), } } } -#[derive(Clone, Copy)] -struct NoMethodMatch; - -/// Response extension used by [`EmptyRouter`] to send the request back to [`Or`] so -/// the other service can be called. -/// -/// Without this we would loose ownership of the request when calling the first -/// service in [`Or`]. We also wouldn't be able to identify if the response came -/// from [`EmptyRouter`] and therefore can be discarded in [`Or`]. -struct FromEmptyRouter { - request: Request, -} - -#[derive(Debug, Clone)] -struct Route { - id: RouteId, - svc: S, - fallback: T, -} - -impl Service> for Route -where - S: Service, Response = Response, Error = Infallible> + Clone, - T: Service, Response = Response, Error = Infallible> + Clone, - B: Send + Sync + 'static, -{ - type Response = Response; - type Error = Infallible; - type Future = RouteFuture; - - #[inline] - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: Request) -> Self::Future { - match req.extensions().get::() { - Some(id) => { - if self.id == *id { - RouteFuture::a(self.svc.clone().oneshot(req)) - } else { - RouteFuture::b(self.fallback.clone().oneshot(req)) - } - } - None => RouteFuture::b(self.fallback.clone().oneshot(req)), - } - } -} - /// A [`Service`] that has been nested inside a router at some path. /// /// Created with [`Router::nest`]. #[derive(Debug, Clone)] -struct Nested { - id: RouteId, +struct Nested { svc: S, - fallback: T, } -impl Service> for Nested +impl Service> for Nested where S: Service, Response = Response, Error = Infallible> + Clone, - T: Service, Response = Response, Error = Infallible> + Clone, B: Send + Sync + 'static, { type Response = Response; type Error = Infallible; - type Future = NestedFuture; + type Future = NestedFuture; #[inline] fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { @@ -874,18 +952,9 @@ where } fn call(&mut self, req: Request) -> Self::Future { - let future = match req.extensions().get::() { - Some(id) => { - if self.id == *id { - RouteFuture::a(self.svc.clone().oneshot(req)) - } else { - RouteFuture::b(self.fallback.clone().oneshot(req)) - } - } - None => RouteFuture::b(self.fallback.clone().oneshot(req)), - }; - - NestedFuture { inner: future } + NestedFuture { + inner: self.svc.clone().oneshot(req), + } } } @@ -954,24 +1023,24 @@ where /// How routes are stored inside a [`Router`]. /// /// You normally shouldn't need to care about this type. -pub struct Routes(CloneBoxService, Response, Infallible>); +pub struct Route(CloneBoxService, Response, Infallible>); -impl Clone for Routes { +impl Clone for Route { fn clone(&self) -> Self { Self(self.0.clone()) } } -impl fmt::Debug for Routes { +impl fmt::Debug for Route { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Router").finish() + f.debug_struct("Route").finish() } } -impl Service> for Routes { +impl Service> for Route { type Response = Response; type Error = Infallible; - type Future = future::RoutesFuture; + type Future = RouteFuture; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { @@ -980,12 +1049,51 @@ impl Service> for Routes { #[inline] fn call(&mut self, req: Request) -> Self::Future { - future::RoutesFuture { + RouteFuture { future: self.0.call(req), } } } +#[derive(Clone, Default)] +struct Node { + inner: matchit::Node, + paths: Vec<(String, RouteId)>, +} + +impl Node { + fn insert( + &mut self, + path: impl Into, + val: RouteId, + ) -> Result<(), matchit::InsertError> { + let path = path.into(); + self.inner.insert(&path, val)?; + self.paths.push((path, val)); + Ok(()) + } + + fn merge(&mut self, other: Node) -> Result<(), matchit::InsertError> { + for (path, id) in other.paths { + self.insert(path, id)?; + } + Ok(()) + } + + fn at<'n, 'p>( + &'n self, + path: &'p str, + ) -> Result, matchit::MatchError> { + self.inner.at(path) + } +} + +impl fmt::Debug for Node { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Node").field("paths", &self.paths).finish() + } +} + #[cfg(test)] mod tests { use super::*; @@ -996,14 +1104,13 @@ mod tests { assert_send::>(); - assert_send::>(); - assert_sync::>(); + assert_send::>(); assert_send::>(); assert_sync::>(); - assert_send::>(); - assert_sync::>(); + assert_send::>(); + assert_sync::>(); assert_send::>(); assert_sync::>(); diff --git a/src/routing/or.rs b/src/routing/or.rs deleted file mode 100644 index 7bf61538..00000000 --- a/src/routing/or.rs +++ /dev/null @@ -1,128 +0,0 @@ -//! [`Or`] used to combine two services into one. - -use super::FromEmptyRouter; -use crate::body::BoxBody; -use futures_util::ready; -use http::{Request, Response}; -use pin_project_lite::pin_project; -use std::{ - convert::Infallible, - future::Future, - pin::Pin, - task::{Context, Poll}, -}; -use tower::{util::Oneshot, ServiceExt}; -use tower_service::Service; - -/// [`tower::Service`] that is the combination of two routers. -/// -/// See [`Router::or`] for more details. -/// -/// [`Router::or`]: super::Router::or -#[derive(Debug, Clone, Copy)] -pub(crate) struct Or { - pub(super) first: A, - pub(super) second: B, -} - -#[test] -fn traits() { - use crate::tests::*; - assert_send::>(); - assert_sync::>(); -} - -impl Service> for Or -where - A: Service, Response = Response, Error = Infallible> + Clone, - B: Service, Response = Response, Error = Infallible> + Clone, - ReqBody: Send + Sync + 'static, - A: Send + 'static, - B: Send + 'static, - A::Future: Send + 'static, - B::Future: Send + 'static, -{ - type Response = Response; - type Error = Infallible; - type Future = ResponseFuture; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: Request) -> Self::Future { - ResponseFuture { - state: State::FirstFuture { - f: self.first.clone().oneshot(req), - }, - second: Some(self.second.clone()), - } - } -} - -pin_project! { - /// Response future for [`Or`]. - pub(crate) struct ResponseFuture - where - A: Service>, - B: Service>, - { - #[pin] - state: State, - second: Option, - } -} - -pin_project! { - #[project = StateProj] - enum State - where - A: Service>, - B: Service>, - { - FirstFuture { #[pin] f: Oneshot> }, - SecondFuture { - #[pin] - f: Oneshot>, - } - } -} - -impl Future for ResponseFuture -where - A: Service, Response = Response, Error = Infallible>, - B: Service, Response = Response, Error = Infallible>, - ReqBody: Send + Sync + 'static, -{ - type Output = Result, Infallible>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - loop { - let mut this = self.as_mut().project(); - - let new_state = match this.state.as_mut().project() { - StateProj::FirstFuture { f } => { - let mut response = ready!(f.poll(cx)?); - - let req = if let Some(ext) = response - .extensions_mut() - .remove::>() - { - ext.request - } else { - return Poll::Ready(Ok(response)); - }; - - let second = this.second.take().expect("future polled after completion"); - - State::SecondFuture { - f: second.oneshot(req), - } - } - StateProj::SecondFuture { f } => return f.poll(cx), - }; - - this.state.set(new_state); - } - } -} diff --git a/src/tests/fallback.rs b/src/tests/fallback.rs new file mode 100644 index 00000000..c5708929 --- /dev/null +++ b/src/tests/fallback.rs @@ -0,0 +1,96 @@ +use super::*; +use crate::handler::Handler; + +#[tokio::test] +async fn basic() { + let app = Router::new() + .route("/foo", get(|| async {})) + .fallback((|| async { "fallback" }).into_service()); + + let client = TestClient::new(app); + + assert_eq!(client.get("/foo").send().await.status(), StatusCode::OK); + + let res = client.get("/does-not-exist").send().await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "fallback"); +} + +#[tokio::test] +async fn nest() { + let app = Router::new() + .nest("/foo", Router::new().route("/bar", get(|| async {}))) + .fallback((|| async { "fallback" }).into_service()); + + let client = TestClient::new(app); + + assert_eq!(client.get("/foo/bar").send().await.status(), StatusCode::OK); + + let res = client.get("/does-not-exist").send().await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "fallback"); +} + +#[tokio::test] +async fn nesting_with_fallback() { + let app = Router::new().nest( + "/foo", + Router::new() + .route("/bar", get(|| async {})) + .fallback((|| async { "fallback" }).into_service()), + ); + + let client = TestClient::new(app); + + assert_eq!(client.get("/foo/bar").send().await.status(), StatusCode::OK); + + // this shouldn't exist because the fallback is inside the nested router + let res = client.get("/does-not-exist").send().await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + + // this should work since we get into the nested router + let res = client.get("/foo/does-not-exist").send().await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "fallback"); +} + +#[tokio::test] +async fn or() { + let one = Router::new().route("/one", get(|| async {})); + let two = Router::new().route("/two", get(|| async {})); + + let app = one + .merge(two) + .fallback((|| async { "fallback" }).into_service()); + + let client = TestClient::new(app); + + assert_eq!(client.get("/one").send().await.status(), StatusCode::OK); + assert_eq!(client.get("/two").send().await.status(), StatusCode::OK); + + let res = client.get("/does-not-exist").send().await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "fallback"); +} + +#[tokio::test] +async fn fallback_on_or() { + let one = Router::new() + .route("/one", get(|| async {})) + .fallback((|| async { "fallback one" }).into_service()); + + let two = Router::new() + .route("/two", get(|| async {})) + .fallback((|| async { "fallback two" }).into_service()); + + let app = one.merge(two); + + let client = TestClient::new(app); + + assert_eq!(client.get("/one").send().await.status(), StatusCode::OK); + assert_eq!(client.get("/two").send().await.status(), StatusCode::OK); + + let res = client.get("/does-not-exist").send().await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "fallback two"); +} diff --git a/src/tests/or.rs b/src/tests/merge.rs similarity index 80% rename from src/tests/or.rs rename to src/tests/merge.rs index 128f14e1..a91dcb57 100644 --- a/src/tests/or.rs +++ b/src/tests/merge.rs @@ -9,7 +9,7 @@ async fn basic() { .route("/foo", get(|| async {})) .route("/bar", get(|| async {})); let two = Router::new().route("/baz", get(|| async {})); - let app = one.or(two); + let app = one.merge(two); let client = TestClient::new(app); @@ -36,28 +36,28 @@ async fn multiple_ors_balanced_differently() { test( "one", one.clone() - .or(two.clone()) - .or(three.clone()) - .or(four.clone()), + .merge(two.clone()) + .merge(three.clone()) + .merge(four.clone()), ) .await; test( "two", one.clone() - .or(two.clone()) - .or(three.clone().or(four.clone())), + .merge(two.clone()) + .merge(three.clone().merge(four.clone())), ) .await; test( "three", one.clone() - .or(two.clone().or(three.clone()).or(four.clone())), + .merge(two.clone().merge(three.clone()).merge(four.clone())), ) .await; - test("four", one.or(two.or(three.or(four)))).await; + test("four", one.merge(two.merge(three.merge(four)))).await; async fn test(name: &str, app: S) where @@ -84,7 +84,7 @@ async fn nested_or() { let bar = Router::new().route("/bar", get(|| async { "bar" })); let baz = Router::new().route("/baz", get(|| async { "baz" })); - let bar_or_baz = bar.or(baz); + let bar_or_baz = bar.merge(baz); let client = TestClient::new(bar_or_baz.clone()); assert_eq!(client.get("/bar").send().await.text().await, "bar"); @@ -99,7 +99,7 @@ async fn nested_or() { async fn or_with_route_following() { let one = Router::new().route("/one", get(|| async { "one" })); let two = Router::new().route("/two", get(|| async { "two" })); - let app = one.or(two).route("/three", get(|| async { "three" })); + let app = one.merge(two).route("/three", get(|| async { "three" })); let client = TestClient::new(app); @@ -119,7 +119,7 @@ async fn layer() { let two = Router::new() .route("/bar", get(|| async {})) .layer(ConcurrencyLimitLayer::new(10)); - let app = one.or(two); + let app = one.merge(two); let client = TestClient::new(app); @@ -140,7 +140,7 @@ async fn layer_and_handle_error() { .layer(HandleErrorLayer::new(|_| StatusCode::REQUEST_TIMEOUT)) .layer(TimeoutLayer::new(Duration::from_millis(10))), ); - let app = one.or(two); + let app = one.merge(two); let client = TestClient::new(app); @@ -152,7 +152,7 @@ async fn layer_and_handle_error() { async fn nesting() { let one = Router::new().route("/foo", get(|| async {})); let two = Router::new().nest("/bar", Router::new().route("/baz", get(|| async {}))); - let app = one.or(two); + let app = one.merge(two); let client = TestClient::new(app); @@ -164,7 +164,7 @@ async fn nesting() { async fn boxed() { let one = Router::new().route("/foo", get(|| async {})); let two = Router::new().route("/bar", get(|| async {})); - let app = one.or(two); + let app = one.merge(two); let client = TestClient::new(app); @@ -176,12 +176,12 @@ async fn boxed() { async fn many_ors() { let app = Router::new() .route("/r1", get(|| async {})) - .or(Router::new().route("/r2", get(|| async {}))) - .or(Router::new().route("/r3", get(|| async {}))) - .or(Router::new().route("/r4", get(|| async {}))) - .or(Router::new().route("/r5", get(|| async {}))) - .or(Router::new().route("/r6", get(|| async {}))) - .or(Router::new().route("/r7", get(|| async {}))); + .merge(Router::new().route("/r2", get(|| async {}))) + .merge(Router::new().route("/r3", get(|| async {}))) + .merge(Router::new().route("/r4", get(|| async {}))) + .merge(Router::new().route("/r5", get(|| async {}))) + .merge(Router::new().route("/r6", get(|| async {}))) + .merge(Router::new().route("/r7", get(|| async {}))); let client = TestClient::new(app); @@ -205,7 +205,7 @@ async fn services() { Ok::<_, Infallible>(Response::new(Body::empty())) })), ) - .or(Router::new().route( + .merge(Router::new().route( "/bar", get(service_fn(|_: Request| async { Ok::<_, Infallible>(Response::new(Body::empty())) @@ -238,7 +238,7 @@ async fn nesting_and_seeing_the_right_uri() { let one = Router::new().nest("/foo", Router::new().route("/bar", get(all_the_uris))); let two = Router::new().route("/foo", get(all_the_uris)); - let client = TestClient::new(one.or(two)); + let client = TestClient::new(one.merge(two)); let res = client.get("/foo/bar").send().await; assert_eq!(res.status(), StatusCode::OK); @@ -271,7 +271,7 @@ async fn nesting_and_seeing_the_right_uri_at_more_levels_of_nesting() { ); let two = Router::new().route("/foo", get(all_the_uris)); - let client = TestClient::new(one.or(two)); + let client = TestClient::new(one.merge(two)); let res = client.get("/foo/bar/baz").send().await; assert_eq!(res.status(), StatusCode::OK); @@ -299,44 +299,44 @@ async fn nesting_and_seeing_the_right_uri_at_more_levels_of_nesting() { #[tokio::test] async fn nesting_and_seeing_the_right_uri_ors_with_nesting() { let one = Router::new().nest( - "/foo", + "/one", Router::new().nest("/bar", Router::new().route("/baz", get(all_the_uris))), ); - let two = Router::new().nest("/foo", Router::new().route("/qux", get(all_the_uris))); - let three = Router::new().route("/foo", get(all_the_uris)); + let two = Router::new().nest("/two", Router::new().route("/qux", get(all_the_uris))); + let three = Router::new().route("/three", get(all_the_uris)); - let client = TestClient::new(one.or(two).or(three)); + let client = TestClient::new(one.merge(two).merge(three)); - let res = client.get("/foo/bar/baz").send().await; + let res = client.get("/one/bar/baz").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( res.json::().await, json!({ "uri": "/baz", "request_uri": "/baz", - "original_uri": "/foo/bar/baz", + "original_uri": "/one/bar/baz", }) ); - let res = client.get("/foo/qux").send().await; + let res = client.get("/two/qux").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( res.json::().await, json!({ "uri": "/qux", "request_uri": "/qux", - "original_uri": "/foo/qux", + "original_uri": "/two/qux", }) ); - let res = client.get("/foo").send().await; + let res = client.get("/three").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( res.json::().await, json!({ - "uri": "/foo", - "request_uri": "/foo", - "original_uri": "/foo", + "uri": "/three", + "request_uri": "/three", + "original_uri": "/three", }) ); } @@ -344,32 +344,32 @@ async fn nesting_and_seeing_the_right_uri_ors_with_nesting() { #[tokio::test] async fn nesting_and_seeing_the_right_uri_ors_with_multi_segment_uris() { let one = Router::new().nest( - "/foo", - Router::new().nest("/bar", Router::new().route("/baz", get(all_the_uris))), + "/one", + Router::new().nest("/foo", Router::new().route("/bar", get(all_the_uris))), ); - let two = Router::new().route("/foo/bar", get(all_the_uris)); + let two = Router::new().route("/two/foo", get(all_the_uris)); - let client = TestClient::new(one.or(two)); + let client = TestClient::new(one.merge(two)); - let res = client.get("/foo/bar/baz").send().await; + let res = client.get("/one/foo/bar").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( res.json::().await, json!({ - "uri": "/baz", - "request_uri": "/baz", - "original_uri": "/foo/bar/baz", + "uri": "/bar", + "request_uri": "/bar", + "original_uri": "/one/foo/bar", }) ); - let res = client.get("/foo/bar").send().await; + let res = client.get("/two/foo").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( res.json::().await, json!({ - "uri": "/foo/bar", - "request_uri": "/foo/bar", - "original_uri": "/foo/bar", + "uri": "/two/foo", + "request_uri": "/two/foo", + "original_uri": "/two/foo", }) ); } diff --git a/src/tests/mod.rs b/src/tests/mod.rs index c1813cb4..cf9f3bff 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -31,11 +31,12 @@ use tower_service::Service; pub(crate) use helpers::*; +mod fallback; mod get_to_head; mod handle_error; mod helpers; +mod merge; mod nest; -mod or; #[tokio::test] async fn hello_world() {