diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 194c8b4d..089f8473 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased - **fixed:** Add `#[must_use]` to `WebSocketUpgrade::on_upgrade` ([#1801]) +- **fixed:** Fix routing issues when loading a `Router` via a dynamic library ([#1806]) + +[#1806]: https://github.com/tokio-rs/axum/pull/1806 [#1801]: https://github.com/tokio-rs/axum/pull/1801 diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index e7b717ea..b13c25bb 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -47,24 +47,12 @@ pub use self::method_routing::{ #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub(crate) struct RouteId(u32); -impl RouteId { - fn next() -> Self { - use std::sync::atomic::{AtomicU32, Ordering}; - // `AtomicU64` isn't supported on all platforms - static ID: AtomicU32 = AtomicU32::new(0); - let id = ID.fetch_add(1, Ordering::Relaxed); - if id == u32::MAX { - panic!("Over `u32::MAX` routes created. If you need this, please file an issue."); - } - Self(id) - } -} - /// The router type for composing handlers and services. pub struct Router { routes: HashMap>, node: Arc, fallback: Fallback, + prev_route_id: RouteId, } impl Clone for Router { @@ -73,6 +61,7 @@ impl Clone for Router { routes: self.routes.clone(), node: Arc::clone(&self.node), fallback: self.fallback.clone(), + prev_route_id: self.prev_route_id, } } } @@ -117,6 +106,7 @@ where routes: Default::default(), node: Default::default(), fallback: Fallback::Default(Route::new(NotFound)), + prev_route_id: RouteId(0), } } @@ -134,7 +124,7 @@ where validate_path(path); - let id = RouteId::next(); + let id = self.next_route_id(); let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self .node @@ -189,7 +179,7 @@ where panic!("Paths must start with a `/`"); } - let id = RouteId::next(); + let id = self.next_route_id(); self.set_node(path, id); self.routes.insert(id, endpoint); self @@ -286,6 +276,7 @@ where routes, node, fallback, + prev_route_id: _, } = other.into(); for (id, route) in routes { @@ -335,6 +326,7 @@ where routes, node: self.node, fallback, + prev_route_id: self.prev_route_id, } } @@ -368,6 +360,7 @@ where routes, node: self.node, fallback: self.fallback, + prev_route_id: self.prev_route_id, } } @@ -419,6 +412,7 @@ where routes, node: self.node, fallback, + prev_route_id: self.prev_route_id, } } @@ -506,6 +500,16 @@ where Endpoint::NestedRouter(router) => router.call_with_state(req, state), } } + + fn next_route_id(&mut self) -> RouteId { + let next_id = self + .prev_route_id + .0 + .checked_add(1) + .expect("Over `u32::MAX` routes created. If you need this, please file an issue."); + self.prev_route_id = RouteId(next_id); + self.prev_route_id + } } impl Router<(), B>