From 822db3b1af766b7372541e7a490ab21a8e748e10 Mon Sep 17 00:00:00 2001 From: Sabrina Jewson Date: Sun, 6 Oct 2024 20:09:06 +0100 Subject: [PATCH] Add `MethodFilter::CONNECT` (#2961) --- axum-extra/CHANGELOG.md | 2 + axum-extra/src/routing/mod.rs | 23 +++++++++ axum/CHANGELOG.md | 12 ++++- axum/src/routing/method_filter.rs | 34 +++++++++++-- axum/src/routing/method_routing.rs | 79 +++++++++++++++++++++++++++++- axum/src/routing/mod.rs | 6 +-- 6 files changed, 146 insertions(+), 10 deletions(-) diff --git a/axum-extra/CHANGELOG.md b/axum-extra/CHANGELOG.md index 714983a5..10d8a930 100644 --- a/axum-extra/CHANGELOG.md +++ b/axum-extra/CHANGELOG.md @@ -7,8 +7,10 @@ and this project adheres to [Semantic Versioning]. # Unreleased +- **added:** Add `RouterExt::typed_connect` ([#2961]) - **added:** Add `json!` for easy construction of JSON responses ([#2962]) +[#2961]: https://github.com/tokio-rs/axum/pull/2961 [#2962]: https://github.com/tokio-rs/axum/pull/2962 # 0.10.0 diff --git a/axum-extra/src/routing/mod.rs b/axum-extra/src/routing/mod.rs index 9d9aa0cb..8fdfac81 100644 --- a/axum-extra/src/routing/mod.rs +++ b/axum-extra/src/routing/mod.rs @@ -131,6 +131,19 @@ pub trait RouterExt: sealed::Sealed { T: SecondElementIs

+ 'static, P: TypedPath; + /// Add a typed `CONNECT` route to the router. + /// + /// The path will be inferred from the first argument to the handler function which must + /// implement [`TypedPath`]. + /// + /// See [`TypedPath`] for more details and examples. + #[cfg(feature = "typed-routing")] + fn typed_connect(self, handler: H) -> Self + where + H: axum::handler::Handler, + T: SecondElementIs

+ 'static, + P: TypedPath; + /// Add another route to the router with an additional "trailing slash redirect" route. /// /// If you add a route _without_ a trailing slash, such as `/foo`, this method will also add a @@ -255,6 +268,16 @@ where self.route(P::PATH, axum::routing::trace(handler)) } + #[cfg(feature = "typed-routing")] + fn typed_connect(self, handler: H) -> Self + where + H: axum::handler::Handler, + T: SecondElementIs

+ 'static, + P: TypedPath, + { + self.route(P::PATH, axum::routing::connect(handler)) + } + #[track_caller] fn route_with_tsr(mut self, path: &str, method_router: MethodRouter) -> Self where diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 36078115..1e6dcc9d 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -5,6 +5,16 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +# Unreleased + +- **added:** Add support for WebSockets over HTTP/2 ([#2894]). + They can be enabled by changing `get(ws_endpoint)` handlers to `any(ws_endpoint)` +- **added:** Add `MethodFilter::CONNECT`, `routing::connect[_service]` + and `MethodRouter::connect[_service]` ([#2961]) + +[#2984]: https://github.com/tokio-rs/axum/pull/2984 +[#2961]: https://github.com/tokio-rs/axum/pull/2961 + # 0.8.0 ## alpha.1 @@ -15,8 +25,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **breaking:** Upgrade matchit to 0.8, changing the path parameter syntax from `/:single` and `/*many` to `/{single}` and `/{*many}`; the old syntax produces a panic to avoid silent change in behavior ([#2645]) - **change:** Update minimum rust version to 1.75 ([#2943]) -- **added:** Add support WebSockets over HTTP/2. - They can be enabled by changing `get(ws_endpoint)` handlers to `any(ws_endpoint)`. [#2473]: https://github.com/tokio-rs/axum/pull/2473 [#2645]: https://github.com/tokio-rs/axum/pull/2645 diff --git a/axum/src/routing/method_filter.rs b/axum/src/routing/method_filter.rs index 1cea4235..040783ec 100644 --- a/axum/src/routing/method_filter.rs +++ b/axum/src/routing/method_filter.rs @@ -9,6 +9,24 @@ use std::{ pub struct MethodFilter(u16); impl MethodFilter { + /// Match `CONNECT` requests. + /// + /// This is useful for implementing HTTP/2's [extended CONNECT method], + /// in which the `:protocol` pseudoheader is read + /// (using [`hyper::ext::Protocol`]) + /// and the connection upgraded to a bidirectional byte stream + /// (using [`hyper::upgrade::on`]). + /// + /// As seen in the [HTTP Upgrade Token Registry], + /// common uses include WebSockets and proxying UDP or IP – + /// though note that when using [`WebSocketUpgrade`] + /// it's more useful to use [`any`](crate::routing::any) + /// as HTTP/1.1 WebSockets need to support `GET`. + /// + /// [extended CONNECT]: https://www.rfc-editor.org/rfc/rfc8441.html#section-4 + /// [HTTP Upgrade Token Registry]: https://www.iana.org/assignments/http-upgrade-tokens/http-upgrade-tokens.xhtml + /// [`WebSocketUpgrade`]: crate::extract::WebSocketUpgrade + pub const CONNECT: Self = Self::from_bits(0b0_0000_0001); /// Match `DELETE` requests. pub const DELETE: Self = Self::from_bits(0b0_0000_0010); /// Match `GET` requests. @@ -71,6 +89,7 @@ impl TryFrom for MethodFilter { fn try_from(m: Method) -> Result { match m { + Method::CONNECT => Ok(MethodFilter::CONNECT), Method::DELETE => Ok(MethodFilter::DELETE), Method::GET => Ok(MethodFilter::GET), Method::HEAD => Ok(MethodFilter::HEAD), @@ -90,6 +109,11 @@ mod tests { #[test] fn from_http_method() { + assert_eq!( + MethodFilter::try_from(Method::CONNECT).unwrap(), + MethodFilter::CONNECT + ); + assert_eq!( MethodFilter::try_from(Method::DELETE).unwrap(), MethodFilter::DELETE @@ -130,9 +154,11 @@ mod tests { MethodFilter::TRACE ); - assert!(MethodFilter::try_from(http::Method::CONNECT) - .unwrap_err() - .to_string() - .contains("CONNECT")); + assert!( + MethodFilter::try_from(http::Method::from_bytes(b"CUSTOM").unwrap()) + .unwrap_err() + .to_string() + .contains("CUSTOM") + ); } } diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 3b62f728..b4a86501 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -59,6 +59,19 @@ macro_rules! top_level_service_fn { ); }; + ( + $name:ident, CONNECT + ) => { + top_level_service_fn!( + /// Route `CONNECT` requests to the given service. + /// + /// See [`MethodFilter::CONNECT`] for when you'd want to use this, + /// and [`get_service`] for an example. + $name, + CONNECT + ); + }; + ( $name:ident, $method:ident ) => { @@ -118,6 +131,19 @@ macro_rules! top_level_handler_fn { ); }; + ( + $name:ident, CONNECT + ) => { + top_level_handler_fn!( + /// Route `CONNECT` requests to the given handler. + /// + /// See [`MethodFilter::CONNECT`] for when you'd want to use this, + /// and [`get`] for an example. + $name, + CONNECT + ); + }; + ( $name:ident, $method:ident ) => { @@ -187,6 +213,19 @@ macro_rules! chained_service_fn { ); }; + ( + $name:ident, CONNECT + ) => { + chained_service_fn!( + /// Chain an additional service that will only accept `CONNECT` requests. + /// + /// See [`MethodFilter::CONNECT`] for when you'd want to use this, + /// and [`MethodRouter::get_service`] for an example. + $name, + CONNECT + ); + }; + ( $name:ident, $method:ident ) => { @@ -250,6 +289,19 @@ macro_rules! chained_handler_fn { ); }; + ( + $name:ident, CONNECT + ) => { + chained_handler_fn!( + /// Chain an additional handler that will only accept `CONNECT` requests. + /// + /// See [`MethodFilter::CONNECT`] for when you'd want to use this, + /// and [`MethodRouter::get`] for an example. + $name, + CONNECT + ); + }; + ( $name:ident, $method:ident ) => { @@ -279,6 +331,7 @@ macro_rules! chained_handler_fn { }; } +top_level_service_fn!(connect_service, CONNECT); top_level_service_fn!(delete_service, DELETE); top_level_service_fn!(get_service, GET); top_level_service_fn!(head_service, HEAD); @@ -382,6 +435,7 @@ where .skip_allow_header() } +top_level_handler_fn!(connect, CONNECT); top_level_handler_fn!(delete, DELETE); top_level_handler_fn!(get, GET); top_level_handler_fn!(head, HEAD); @@ -498,6 +552,7 @@ pub struct MethodRouter { post: MethodEndpoint, put: MethodEndpoint, trace: MethodEndpoint, + connect: MethodEndpoint, fallback: Fallback, allow_header: AllowHeader, } @@ -539,6 +594,7 @@ impl fmt::Debug for MethodRouter { .field("post", &self.post) .field("put", &self.put) .field("trace", &self.trace) + .field("connect", &self.connect) .field("fallback", &self.fallback) .field("allow_header", &self.allow_header) .finish() @@ -583,6 +639,7 @@ where ) } + chained_handler_fn!(connect, CONNECT); chained_handler_fn!(delete, DELETE); chained_handler_fn!(get, GET); chained_handler_fn!(head, HEAD); @@ -690,6 +747,7 @@ where post: MethodEndpoint::None, put: MethodEndpoint::None, trace: MethodEndpoint::None, + connect: MethodEndpoint::None, allow_header: AllowHeader::None, fallback: Fallback::Default(fallback), } @@ -706,6 +764,7 @@ where post: self.post.with_state(&state), put: self.put.with_state(&state), trace: self.trace.with_state(&state), + connect: self.connect.with_state(&state), allow_header: self.allow_header, fallback: self.fallback.with_state(state), } @@ -854,9 +913,20 @@ where &["DELETE"], ); + set_endpoint( + "CONNECT", + &mut self.options, + &endpoint, + filter, + MethodFilter::CONNECT, + &mut self.allow_header, + &["CONNECT"], + ); + self } + chained_service_fn!(connect_service, CONNECT); chained_service_fn!(delete_service, DELETE); chained_service_fn!(get_service, GET); chained_service_fn!(head_service, HEAD); @@ -900,6 +970,7 @@ where post: self.post.map(layer_fn.clone()), put: self.put.map(layer_fn.clone()), trace: self.trace.map(layer_fn.clone()), + connect: self.connect.map(layer_fn.clone()), fallback: self.fallback.map(layer_fn), allow_header: self.allow_header, } @@ -924,6 +995,7 @@ where && self.post.is_none() && self.put.is_none() && self.trace.is_none() + && self.connect.is_none() { panic!( "Adding a route_layer before any routes is a no-op. \ @@ -944,7 +1016,8 @@ where self.patch = self.patch.map(layer_fn.clone()); self.post = self.post.map(layer_fn.clone()); self.put = self.put.map(layer_fn.clone()); - self.trace = self.trace.map(layer_fn); + self.trace = self.trace.map(layer_fn.clone()); + self.connect = self.connect.map(layer_fn); self } @@ -985,6 +1058,7 @@ where self.post = merge_inner(path, "POST", self.post, other.post); self.put = merge_inner(path, "PUT", self.put, other.put); self.trace = merge_inner(path, "TRACE", self.trace, other.trace); + self.connect = merge_inner(path, "CONNECT", self.connect, other.connect); self.fallback = self .fallback @@ -1058,6 +1132,7 @@ where post, put, trace, + connect, fallback, allow_header, } = self; @@ -1071,6 +1146,7 @@ where call!(req, method, PUT, put); call!(req, method, DELETE, delete); call!(req, method, TRACE, trace); + call!(req, method, CONNECT, connect); let future = fallback.clone().call_with_state(req, state); @@ -1113,6 +1189,7 @@ impl Clone for MethodRouter { post: self.post.clone(), put: self.put.clone(), trace: self.trace.clone(), + connect: self.connect.clone(), fallback: self.fallback.clone(), allow_header: self.allow_header.clone(), } diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 9987dd4f..54dfbc77 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -40,9 +40,9 @@ mod tests; pub use self::{into_make_service::IntoMakeService, method_filter::MethodFilter, route::Route}; pub use self::method_routing::{ - any, any_service, delete, delete_service, get, get_service, head, head_service, on, on_service, - options, options_service, patch, patch_service, post, post_service, put, put_service, trace, - trace_service, MethodRouter, + any, any_service, connect, connect_service, delete, delete_service, get, get_service, head, + head_service, on, on_service, options, options_service, patch, patch_service, post, + post_service, put, put_service, trace, trace_service, MethodRouter, }; macro_rules! panic_on_err {