Work around for http2 hang with `Or` (#199)

This is a nasty hack that works around
https://github.com/hyperium/hyper/issues/2621.

Fixes https://github.com/tokio-rs/axum/issues/191
This commit is contained in:
David Pedersen 2021-08-17 19:00:24 +02:00 committed by GitHub
parent 97c140cdf7
commit 93cdfe8c5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 175 additions and 5 deletions

View File

@ -16,7 +16,7 @@ async fn main() {
tracing_subscriber::fmt::init();
// build our application with a route
let app = route("/", get(handler));
let app = route("/foo", get(handler)).or(route("/bar", get(handler)));
// run it
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));

View File

@ -547,7 +547,11 @@ where
fn call(&mut self, request: Request<B>) -> Self::Future {
let mut res = Response::new(crate::body::empty());
res.extensions_mut().insert(FromEmptyRouter { request });
if request.extensions().get::<OrDepth>().is_some() {
res.extensions_mut().insert(FromEmptyRouter { request });
}
*res.status_mut() = self.status;
EmptyRouterFuture {
future: futures_util::future::ok(res),
@ -565,6 +569,39 @@ struct FromEmptyRouter<B> {
request: Request<B>,
}
/// We need to track whether we're inside an `Or` or not, and only if we then
/// should we save the request into the response extensions.
///
/// This is to work around https://github.com/hyperium/hyper/issues/2621.
///
/// Since ours can be nested we have to track the depth to know when we're
/// leaving the top most `Or`.
///
/// Hopefully when https://github.com/hyperium/hyper/issues/2621 is resolved we
/// can remove this nasty hack.
#[derive(Debug)]
struct OrDepth(usize);
impl OrDepth {
fn new() -> Self {
Self(1)
}
fn increment(&mut self) {
self.0 += 1;
}
fn decrement(&mut self) {
self.0 -= 1;
}
}
impl PartialEq<usize> for &mut OrDepth {
fn eq(&self, other: &usize) -> bool {
self.0 == *other
}
}
#[derive(Debug, Clone)]
pub(crate) struct PathPattern(Arc<Inner>);

View File

@ -1,6 +1,6 @@
//! [`Or`] used to combine two services into one.
use super::{FromEmptyRouter, RoutingDsl};
use super::{FromEmptyRouter, OrDepth, RoutingDsl};
use crate::body::BoxBody;
use futures_util::ready;
use http::{Request, Response};
@ -46,7 +46,13 @@ where
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
if let Some(count) = req.extensions_mut().get_mut::<OrDepth>() {
count.increment();
} else {
req.extensions_mut().insert(OrDepth::new());
}
ResponseFuture {
state: State::FirstFuture {
f: self.first.clone().oneshot(req),
@ -100,7 +106,7 @@ where
StateProj::FirstFuture { f } => {
let mut response = ready!(f.poll(cx)?);
let req = if let Some(ext) = response
let mut req = if let Some(ext) = response
.extensions_mut()
.remove::<FromEmptyRouter<ReqBody>>()
{
@ -109,6 +115,18 @@ where
return Poll::Ready(Ok(response));
};
let mut leaving_outermost_or = false;
if let Some(depth) = req.extensions_mut().get_mut::<OrDepth>() {
if depth == 1 {
leaving_outermost_or = true;
} else {
depth.decrement();
}
}
if leaving_outermost_or {
req.extensions_mut().remove::<OrDepth>();
}
let second = this.second.take().expect("future polled after completion");
State::SecondFuture {

View File

@ -41,6 +41,121 @@ async fn basic() {
assert_eq!(res.status(), StatusCode::NOT_FOUND);
}
#[tokio::test]
async fn multiple_ors_balanced_differently() {
let one = route("/one", get(|| async { "one" }));
let two = route("/two", get(|| async { "two" }));
let three = route("/three", get(|| async { "three" }));
let four = route("/four", get(|| async { "four" }));
test(
"one",
one.clone()
.or(two.clone())
.or(three.clone())
.or(four.clone()),
)
.await;
test(
"two",
one.clone()
.or(two.clone())
.or(three.clone().or(four.clone())),
)
.await;
test(
"three",
one.clone()
.or(two.clone().or(three.clone()).or(four.clone())),
)
.await;
test("four", one.or(two.or(three.or(four)))).await;
async fn test<S, ResBody>(name: &str, app: S)
where
S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
ResBody: http_body::Body + Send + 'static,
ResBody::Data: Send,
ResBody::Error: Into<BoxError>,
S::Future: Send,
S::Error: Into<BoxError>,
{
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
for n in ["one", "two", "three", "four"].iter() {
println!("running: {} / {}", name, n);
let res = client
.get(format!("http://{}/{}", addr, n))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), *n);
}
}
}
#[tokio::test]
async fn or_nested_inside_other_thing() {
let inner = route("/bar", get(|| async {})).or(route("/baz", get(|| async {})));
let app = nest("/foo", inner);
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client
.get(format!("http://{}/foo/bar", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let res = client
.get(format!("http://{}/foo/baz", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
}
#[tokio::test]
async fn or_with_route_following() {
let one = route("/one", get(|| async { "one" }));
let two = route("/two", get(|| async { "two" }));
let app = one.or(two).route("/three", get(|| async { "three" }));
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client
.get(format!("http://{}/one", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let res = client
.get(format!("http://{}/two", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let res = client
.get(format!("http://{}/three", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
}
#[tokio::test]
async fn layer() {
let one = route("/foo", get(|| async {}));