Add `MatchedPath` extractor (#412)

Fixes #386
This commit is contained in:
David Pedersen 2021-10-25 23:38:29 +02:00 committed by GitHub
parent e43bdf0ecf
commit 02a035fb14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 139 additions and 3 deletions

View File

@ -136,6 +136,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
without trailing a slash.
- **breaking:** `EmptyRouter` has been renamed to `MethodNotAllowed` as its only
used in method routers and not in path routers (`Router`)
- **added:** Add `extract::MatchedPath` for accessing path in router that
matched request ([#412])
[#339]: https://github.com/tokio-rs/axum/pull/339
[#286]: https://github.com/tokio-rs/axum/pull/286
@ -147,6 +149,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#404]: https://github.com/tokio-rs/axum/pull/404
[#405]: https://github.com/tokio-rs/axum/pull/405
[#408]: https://github.com/tokio-rs/axum/pull/408
[#412]: https://github.com/tokio-rs/axum/pull/412
# 0.2.8 (07. October, 2021)

View File

@ -60,6 +60,7 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread", "net"] }
tokio-stream = "0.1"
tracing = "0.1"
uuid = { version = "0.8", features = ["serde", "v4"] }
[dev-dependencies.tower]

View File

@ -0,0 +1,86 @@
use super::{rejection::*, FromRequest, RequestParts};
use async_trait::async_trait;
use std::sync::Arc;
/// Access the path in the router that matches the request.
///
/// ```
/// use axum::{
/// Router,
/// extract::MatchedPath,
/// routing::get,
/// };
///
/// let app = Router::new().route(
/// "/users/:id",
/// get(|path: MatchedPath| async move {
/// let path = path.as_str();
/// // `path` will be "/users/:id"
/// })
/// );
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// `MatchedPath` can also be accessed from middleware via request extensions.
/// This is useful for example with [`Trace`](tower_http::trace::Trace) to
/// create a span that contains the matched path:
///
/// ```
/// use axum::{
/// Router,
/// extract::MatchedPath,
/// http::Request,
/// routing::get,
/// };
/// use tower_http::trace::TraceLayer;
///
/// let app = Router::new()
/// .route("/users/:id", get(|| async { /* ... */ }))
/// .layer(
/// TraceLayer::new_for_http().make_span_with(|req: &Request<_>| {
/// let path = if let Some(path) = req.extensions().get::<MatchedPath>() {
/// path.as_str()
/// } else {
/// req.uri().path()
/// };
/// tracing::info_span!("http-request", %path)
/// }),
/// );
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
#[derive(Clone, Debug)]
pub struct MatchedPath(pub(crate) Arc<str>);
impl MatchedPath {
/// Returns a `str` representation of the path.
pub fn as_str(&self) -> &str {
&*self.0
}
}
#[async_trait]
impl<B> FromRequest<B> for MatchedPath
where
B: Send,
{
type Rejection = MatchedPathRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let extensions =
req.extensions()
.ok_or(MatchedPathRejection::ExtensionsAlreadyExtracted(
ExtensionsAlreadyExtracted,
))?;
let matched_path = extensions
.get::<Self>()
.ok_or(MatchedPathRejection::MatchedPathMissing(MatchedPathMissing))?
.clone();
Ok(matched_path)
}
}

View File

@ -172,6 +172,7 @@ pub mod ws;
mod content_length_limit;
mod extension;
mod form;
mod matched_path;
mod path;
mod query;
mod raw_query;
@ -186,6 +187,7 @@ pub use self::{
extension::Extension,
extractor_middleware::extractor_middleware,
form::Form,
matched_path::MatchedPath,
path::Path,
query::Query,
raw_query::RawQuery,

View File

@ -273,6 +273,23 @@ composite_rejection! {
}
}
define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "No matched path found"]
/// Rejection if no matched path could be found.
///
/// See [`MatchedPath`](super::MatchedPath) for more details.
pub struct MatchedPathMissing;
}
composite_rejection! {
/// Rejection used for [`MatchedPath`](super::MatchedPath).
pub enum MatchedPathRejection {
ExtensionsAlreadyExtracted,
MatchedPathMissing,
}
}
/// Rejection used for [`ContentLengthLimit`](super::ContentLengthLimit).
///
/// Contains one variant for each way the

View File

@ -738,6 +738,11 @@ where
let id = *match_.value;
req.extensions_mut().insert(id);
if let Some(matched_path) = self.node.paths.get(&id) {
req.extensions_mut()
.insert(crate::extract::MatchedPath(matched_path.clone()));
}
let params = match_
.params
.iter()
@ -1059,7 +1064,7 @@ impl<B> Service<Request<B>> for Route<B> {
#[derive(Clone, Default)]
struct Node {
inner: matchit::Node<RouteId>,
paths: Vec<(Arc<str>, RouteId)>,
paths: HashMap<RouteId, Arc<str>>,
}
impl Node {
@ -1070,12 +1075,12 @@ impl Node {
) -> Result<(), matchit::InsertError> {
let path = path.into();
self.inner.insert(&path, val)?;
self.paths.push((path.into(), val));
self.paths.insert(val, path.into());
Ok(())
}
fn merge(&mut self, other: Node) -> Result<(), matchit::InsertError> {
for (path, id) in other.paths {
for (id, path) in other.paths {
self.insert(&*path, id)?;
}
Ok(())

View File

@ -1,6 +1,7 @@
#![allow(clippy::blacklisted_name)]
use crate::error_handling::HandleErrorLayer;
use crate::extract::MatchedPath;
use crate::BoxError;
use crate::{
extract::{self, Path},
@ -27,6 +28,7 @@ use std::{
};
use tower::{service_fn, timeout::TimeoutLayer, ServiceBuilder};
use tower_http::auth::RequireAuthorizationLayer;
use tower_http::trace::TraceLayer;
use tower_service::Service;
pub(crate) use helpers::*;
@ -618,6 +620,26 @@ async fn with_and_without_trailing_slash() {
assert_eq!(res.text().await, "without tsr");
}
#[tokio::test]
async fn access_matched_path() {
let app = Router::new()
.route(
"/:key",
get(|path: MatchedPath| async move { path.as_str().to_string() }),
)
.layer(
TraceLayer::new_for_http().make_span_with(|req: &Request<_>| {
let path = req.extensions().get::<MatchedPath>().unwrap().as_str();
tracing::info_span!("http-request", %path)
}),
);
let client = TestClient::new(app);
let res = client.get("/foo").send().await;
assert_eq!(res.text().await, "/:key");
}
pub(crate) fn assert_send<T: Send>() {}
pub(crate) fn assert_sync<T: Sync>() {}
pub(crate) fn assert_unpin<T: Unpin>() {}