Fix routing issues when loading a `Router` via a dynamic library (#1806)

This commit is contained in:
David Pedersen 2023-03-03 13:23:53 +01:00 committed by GitHub
parent 6075be60ed
commit 5a58edac16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 15 deletions

View File

@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
- **fixed:** Add `#[must_use]` to `WebSocketUpgrade::on_upgrade` ([#1801])
- **fixed:** Fix routing issues when loading a `Router` via a dynamic library ([#1806])
[#1806]: https://github.com/tokio-rs/axum/pull/1806
[#1801]: https://github.com/tokio-rs/axum/pull/1801

View File

@ -47,24 +47,12 @@ pub use self::method_routing::{
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) struct RouteId(u32);
impl RouteId {
fn next() -> Self {
use std::sync::atomic::{AtomicU32, Ordering};
// `AtomicU64` isn't supported on all platforms
static ID: AtomicU32 = AtomicU32::new(0);
let id = ID.fetch_add(1, Ordering::Relaxed);
if id == u32::MAX {
panic!("Over `u32::MAX` routes created. If you need this, please file an issue.");
}
Self(id)
}
}
/// The router type for composing handlers and services.
pub struct Router<S = (), B = Body> {
routes: HashMap<RouteId, Endpoint<S, B>>,
node: Arc<Node>,
fallback: Fallback<S, B>,
prev_route_id: RouteId,
}
impl<S, B> Clone for Router<S, B> {
@ -73,6 +61,7 @@ impl<S, B> Clone for Router<S, B> {
routes: self.routes.clone(),
node: Arc::clone(&self.node),
fallback: self.fallback.clone(),
prev_route_id: self.prev_route_id,
}
}
}
@ -117,6 +106,7 @@ where
routes: Default::default(),
node: Default::default(),
fallback: Fallback::Default(Route::new(NotFound)),
prev_route_id: RouteId(0),
}
}
@ -134,7 +124,7 @@ where
validate_path(path);
let id = RouteId::next();
let id = self.next_route_id();
let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self
.node
@ -189,7 +179,7 @@ where
panic!("Paths must start with a `/`");
}
let id = RouteId::next();
let id = self.next_route_id();
self.set_node(path, id);
self.routes.insert(id, endpoint);
self
@ -286,6 +276,7 @@ where
routes,
node,
fallback,
prev_route_id: _,
} = other.into();
for (id, route) in routes {
@ -335,6 +326,7 @@ where
routes,
node: self.node,
fallback,
prev_route_id: self.prev_route_id,
}
}
@ -368,6 +360,7 @@ where
routes,
node: self.node,
fallback: self.fallback,
prev_route_id: self.prev_route_id,
}
}
@ -419,6 +412,7 @@ where
routes,
node: self.node,
fallback,
prev_route_id: self.prev_route_id,
}
}
@ -506,6 +500,16 @@ where
Endpoint::NestedRouter(router) => router.call_with_state(req, state),
}
}
fn next_route_id(&mut self) -> RouteId {
let next_id = self
.prev_route_id
.0
.checked_add(1)
.expect("Over `u32::MAX` routes created. If you need this, please file an issue.");
self.prev_route_id = RouteId(next_id);
self.prev_route_id
}
}
impl<B> Router<(), B>