mirror of https://github.com/tokio-rs/axum
Work around for http2 hang with `Or` (#199)
This is a nasty hack that works around https://github.com/hyperium/hyper/issues/2621. Fixes https://github.com/tokio-rs/axum/issues/191
This commit is contained in:
parent
97c140cdf7
commit
93cdfe8c5f
|
@ -16,7 +16,7 @@ async fn main() {
|
|||
tracing_subscriber::fmt::init();
|
||||
|
||||
// build our application with a route
|
||||
let app = route("/", get(handler));
|
||||
let app = route("/foo", get(handler)).or(route("/bar", get(handler)));
|
||||
|
||||
// run it
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
||||
|
|
|
@ -547,7 +547,11 @@ where
|
|||
|
||||
fn call(&mut self, request: Request<B>) -> Self::Future {
|
||||
let mut res = Response::new(crate::body::empty());
|
||||
res.extensions_mut().insert(FromEmptyRouter { request });
|
||||
|
||||
if request.extensions().get::<OrDepth>().is_some() {
|
||||
res.extensions_mut().insert(FromEmptyRouter { request });
|
||||
}
|
||||
|
||||
*res.status_mut() = self.status;
|
||||
EmptyRouterFuture {
|
||||
future: futures_util::future::ok(res),
|
||||
|
@ -565,6 +569,39 @@ struct FromEmptyRouter<B> {
|
|||
request: Request<B>,
|
||||
}
|
||||
|
||||
/// We need to track whether we're inside an `Or` or not, and only if we then
|
||||
/// should we save the request into the response extensions.
|
||||
///
|
||||
/// This is to work around https://github.com/hyperium/hyper/issues/2621.
|
||||
///
|
||||
/// Since ours can be nested we have to track the depth to know when we're
|
||||
/// leaving the top most `Or`.
|
||||
///
|
||||
/// Hopefully when https://github.com/hyperium/hyper/issues/2621 is resolved we
|
||||
/// can remove this nasty hack.
|
||||
#[derive(Debug)]
|
||||
struct OrDepth(usize);
|
||||
|
||||
impl OrDepth {
|
||||
fn new() -> Self {
|
||||
Self(1)
|
||||
}
|
||||
|
||||
fn increment(&mut self) {
|
||||
self.0 += 1;
|
||||
}
|
||||
|
||||
fn decrement(&mut self) {
|
||||
self.0 -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq<usize> for &mut OrDepth {
|
||||
fn eq(&self, other: &usize) -> bool {
|
||||
self.0 == *other
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct PathPattern(Arc<Inner>);
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
//! [`Or`] used to combine two services into one.
|
||||
|
||||
use super::{FromEmptyRouter, RoutingDsl};
|
||||
use super::{FromEmptyRouter, OrDepth, RoutingDsl};
|
||||
use crate::body::BoxBody;
|
||||
use futures_util::ready;
|
||||
use http::{Request, Response};
|
||||
|
@ -46,7 +46,13 @@ where
|
|||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
|
||||
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
|
||||
if let Some(count) = req.extensions_mut().get_mut::<OrDepth>() {
|
||||
count.increment();
|
||||
} else {
|
||||
req.extensions_mut().insert(OrDepth::new());
|
||||
}
|
||||
|
||||
ResponseFuture {
|
||||
state: State::FirstFuture {
|
||||
f: self.first.clone().oneshot(req),
|
||||
|
@ -100,7 +106,7 @@ where
|
|||
StateProj::FirstFuture { f } => {
|
||||
let mut response = ready!(f.poll(cx)?);
|
||||
|
||||
let req = if let Some(ext) = response
|
||||
let mut req = if let Some(ext) = response
|
||||
.extensions_mut()
|
||||
.remove::<FromEmptyRouter<ReqBody>>()
|
||||
{
|
||||
|
@ -109,6 +115,18 @@ where
|
|||
return Poll::Ready(Ok(response));
|
||||
};
|
||||
|
||||
let mut leaving_outermost_or = false;
|
||||
if let Some(depth) = req.extensions_mut().get_mut::<OrDepth>() {
|
||||
if depth == 1 {
|
||||
leaving_outermost_or = true;
|
||||
} else {
|
||||
depth.decrement();
|
||||
}
|
||||
}
|
||||
if leaving_outermost_or {
|
||||
req.extensions_mut().remove::<OrDepth>();
|
||||
}
|
||||
|
||||
let second = this.second.take().expect("future polled after completion");
|
||||
|
||||
State::SecondFuture {
|
||||
|
|
115
src/tests/or.rs
115
src/tests/or.rs
|
@ -41,6 +41,121 @@ async fn basic() {
|
|||
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn multiple_ors_balanced_differently() {
|
||||
let one = route("/one", get(|| async { "one" }));
|
||||
let two = route("/two", get(|| async { "two" }));
|
||||
let three = route("/three", get(|| async { "three" }));
|
||||
let four = route("/four", get(|| async { "four" }));
|
||||
|
||||
test(
|
||||
"one",
|
||||
one.clone()
|
||||
.or(two.clone())
|
||||
.or(three.clone())
|
||||
.or(four.clone()),
|
||||
)
|
||||
.await;
|
||||
|
||||
test(
|
||||
"two",
|
||||
one.clone()
|
||||
.or(two.clone())
|
||||
.or(three.clone().or(four.clone())),
|
||||
)
|
||||
.await;
|
||||
|
||||
test(
|
||||
"three",
|
||||
one.clone()
|
||||
.or(two.clone().or(three.clone()).or(four.clone())),
|
||||
)
|
||||
.await;
|
||||
|
||||
test("four", one.or(two.or(three.or(four)))).await;
|
||||
|
||||
async fn test<S, ResBody>(name: &str, app: S)
|
||||
where
|
||||
S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
|
||||
ResBody: http_body::Body + Send + 'static,
|
||||
ResBody::Data: Send,
|
||||
ResBody::Error: Into<BoxError>,
|
||||
S::Future: Send,
|
||||
S::Error: Into<BoxError>,
|
||||
{
|
||||
let addr = run_in_background(app).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
for n in ["one", "two", "three", "four"].iter() {
|
||||
println!("running: {} / {}", name, n);
|
||||
let res = client
|
||||
.get(format!("http://{}/{}", addr, n))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(res.text().await.unwrap(), *n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn or_nested_inside_other_thing() {
|
||||
let inner = route("/bar", get(|| async {})).or(route("/baz", get(|| async {})));
|
||||
let app = nest("/foo", inner);
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}/foo/bar", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}/foo/baz", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn or_with_route_following() {
|
||||
let one = route("/one", get(|| async { "one" }));
|
||||
let two = route("/two", get(|| async { "two" }));
|
||||
let app = one.or(two).route("/three", get(|| async { "three" }));
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}/one", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}/two", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}/three", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn layer() {
|
||||
let one = route("/foo", get(|| async {}));
|
||||
|
|
Loading…
Reference in New Issue