From 39cc596e45ec1c48eae9b7184c187fd7d1b5be5f Mon Sep 17 00:00:00 2001 From: Mikhail Antoshkin <51194616+mikhailantoshkin@users.noreply.github.com> Date: Sat, 18 Nov 2023 21:38:30 +0900 Subject: [PATCH] Add OptionalQuery extractor (#2310) Co-authored-by: David Pedersen --- axum-extra/CHANGELOG.md | 2 + axum-extra/src/extract/mod.rs | 2 +- axum-extra/src/extract/query.rs | 203 +++++++++++++++++++++++++++++++- 3 files changed, 205 insertions(+), 2 deletions(-) diff --git a/axum-extra/CHANGELOG.md b/axum-extra/CHANGELOG.md index 8942f261..a3d22dff 100644 --- a/axum-extra/CHANGELOG.md +++ b/axum-extra/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning]. # Unreleased +- **added:** `OptionalQuery` extractor ([#2310]) - **added:** `TypedHeader` which used to be in `axum` ([#1850]) - **added:** `Clone` implementation for `ErasedJson` ([#2142]) - **breaking:** Update to prost 0.12. Used for the `Protobuf` extractor @@ -14,6 +15,7 @@ and this project adheres to [Semantic Versioning]. [#1850]: https://github.com/tokio-rs/axum/pull/1850 [#2142]: https://github.com/tokio-rs/axum/pull/2142 +[#2310]: https://github.com/tokio-rs/axum/pull/2310 # 0.7.4 (18. April, 2023) diff --git a/axum-extra/src/extract/mod.rs b/axum-extra/src/extract/mod.rs index c7946413..8435fc84 100644 --- a/axum-extra/src/extract/mod.rs +++ b/axum-extra/src/extract/mod.rs @@ -31,7 +31,7 @@ pub use self::cookie::SignedCookieJar; pub use self::form::{Form, FormRejection}; #[cfg(feature = "query")] -pub use self::query::{Query, QueryRejection}; +pub use self::query::{OptionalQuery, OptionalQueryRejection, Query, QueryRejection}; #[cfg(feature = "multipart")] pub use self::multipart::Multipart; diff --git a/axum-extra/src/extract/query.rs b/axum-extra/src/extract/query.rs index b4f5bebd..bdeaf78e 100644 --- a/axum-extra/src/extract/query.rs +++ b/axum-extra/src/extract/query.rs @@ -112,6 +112,124 @@ impl std::error::Error for QueryRejection { } } +/// Extractor that deserializes query strings into `None` if no query parameters are present. +/// Otherwise behaviour is identical to [`Query`] +/// +/// `T` is expected to implement [`serde::Deserialize`]. +/// +/// # Example +/// +/// ```rust,no_run +/// use axum::{routing::get, Router}; +/// use axum_extra::extract::OptionalQuery; +/// use serde::Deserialize; +/// +/// #[derive(Deserialize)] +/// struct Pagination { +/// page: usize, +/// per_page: usize, +/// } +/// +/// // This will parse query strings like `?page=2&per_page=30` into `Some(Pagination)` and +/// // empty query string into `None` +/// async fn list_things(OptionalQuery(pagination): OptionalQuery) { +/// match pagination { +/// Some(Pagination{ page, per_page }) => { /* return specified page */ }, +/// None => { /* return fist page */ } +/// } +/// // ... +/// } +/// +/// let app = Router::new().route("/list_things", get(list_things)); +/// # let _: Router = app; +/// ``` +/// +/// If the query string cannot be parsed it will reject the request with a `400 +/// Bad Request` response. +/// +/// For handling values being empty vs missing see the [query-params-with-empty-strings][example] +/// example. +/// +/// [example]: https://github.com/tokio-rs/axum/blob/main/examples/query-params-with-empty-strings/src/main.rs +#[cfg_attr(docsrs, doc(cfg(feature = "query")))] +#[derive(Debug, Clone, Copy, Default)] +pub struct OptionalQuery(pub Option); + +#[async_trait] +impl FromRequestParts for OptionalQuery +where + T: DeserializeOwned, + S: Send + Sync, +{ + type Rejection = OptionalQueryRejection; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + if let Some(query) = parts.uri.query() { + let value = serde_html_form::from_str(query).map_err(|err| { + OptionalQueryRejection::FailedToDeserializeQueryString(Error::new(err)) + })?; + Ok(OptionalQuery(Some(value))) + } else { + Ok(OptionalQuery(None)) + } + } +} + +impl std::ops::Deref for OptionalQuery { + type Target = Option; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for OptionalQuery { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +/// Rejection used for [`OptionalQuery`]. +/// +/// Contains one variant for each way the [`OptionalQuery`] extractor can fail. +#[derive(Debug)] +#[non_exhaustive] +#[cfg(feature = "query")] +pub enum OptionalQueryRejection { + #[allow(missing_docs)] + FailedToDeserializeQueryString(Error), +} + +impl IntoResponse for OptionalQueryRejection { + fn into_response(self) -> Response { + match self { + Self::FailedToDeserializeQueryString(inner) => ( + StatusCode::BAD_REQUEST, + format!("Failed to deserialize query string: {inner}"), + ) + .into_response(), + } + } +} + +impl fmt::Display for OptionalQueryRejection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::FailedToDeserializeQueryString(inner) => inner.fmt(f), + } + } +} + +impl std::error::Error for OptionalQueryRejection { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::FailedToDeserializeQueryString(inner) => Some(inner), + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -121,7 +239,7 @@ mod tests { use serde::Deserialize; #[tokio::test] - async fn supports_multiple_values() { + async fn query_supports_multiple_values() { #[derive(Deserialize)] struct Data { #[serde(rename = "value")] @@ -145,4 +263,87 @@ mod tests { assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "one,two"); } + + #[tokio::test] + async fn optional_query_supports_multiple_values() { + #[derive(Deserialize)] + struct Data { + #[serde(rename = "value")] + values: Vec, + } + + let app = Router::new().route( + "/", + post(|OptionalQuery(data): OptionalQuery| async move { + data.map(|Data { values }| values.join(",")) + .unwrap_or("None".to_owned()) + }), + ); + + let client = TestClient::new(app); + + let res = client + .post("/?value=one&value=two") + .header(CONTENT_TYPE, "application/x-www-form-urlencoded") + .body("") + .send() + .await; + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "one,two"); + } + + #[tokio::test] + async fn optional_query_deserializes_no_parameters_into_none() { + #[derive(Deserialize)] + struct Data { + value: String, + } + + let app = Router::new().route( + "/", + post(|OptionalQuery(data): OptionalQuery| async move { + match data { + None => "None".into(), + Some(data) => data.value, + } + }), + ); + + let client = TestClient::new(app); + + let res = client.post("/").body("").send().await; + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "None"); + } + + #[tokio::test] + async fn optional_query_preserves_parsing_errors() { + #[derive(Deserialize)] + struct Data { + value: String, + } + + let app = Router::new().route( + "/", + post(|OptionalQuery(data): OptionalQuery| async move { + match data { + None => "None".into(), + Some(data) => data.value, + } + }), + ); + + let client = TestClient::new(app); + + let res = client + .post("/?other=something") + .header(CONTENT_TYPE, "application/x-www-form-urlencoded") + .body("") + .send() + .await; + + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + } }