mirror of https://github.com/tokio-rs/axum
Add extractor for remote connection info (#55)
Fixes https://github.com/tokio-rs/axum/issues/43 With this you can get the remote address like so: ```rust use axum::{prelude::*, extract::ConnectInfo}; use std::net::SocketAddr; let app = route("/", get(handler)); async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String { format!("Hello {}", addr) } // Starting the app with `into_make_service_with_connect_info` is required // for `ConnectInfo` to work. let make_svc = app.into_make_service_with_connect_info::<SocketAddr, _>(); hyper::Server::bind(&"0.0.0.0:3000".parse().unwrap()) .serve(make_svc) .await .expect("server failed"); ``` This API is fully generic and supports whatever transport layer you're using with Hyper. I've updated the unix domain socket example to extract `peer_creds` and `peer_addr`.
This commit is contained in:
parent
407aa533d7
commit
f67abd1ee2
|
@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- Implement `Sink` for `WebSocket` ([#52](https://github.com/tokio-rs/axum/pull/52))
|
||||
- Implement `Deref` most extractors ([#56](https://github.com/tokio-rs/axum/pull/56))
|
||||
- Return `405 Method Not Allowed` for unsupported method for route ([#63](https://github.com/tokio-rs/axum/pull/63))
|
||||
- Add extractor for remote connection info ([#55](https://github.com/tokio-rs/axum/pull/55))
|
||||
|
||||
## Breaking changes
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ bytes = "1.0"
|
|||
futures-util = "0.3"
|
||||
http = "0.2"
|
||||
http-body = "0.4"
|
||||
hyper = { version = "0.14", features = ["server", "tcp"] }
|
||||
hyper = { version = "0.14", features = ["server", "tcp", "http1"] }
|
||||
pin-project = "1.0"
|
||||
regex = "1.5"
|
||||
serde = "1.0"
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
use axum::prelude::*;
|
||||
use axum::{
|
||||
extract::connect_info::{self, ConnectInfo},
|
||||
prelude::*,
|
||||
};
|
||||
use futures::ready;
|
||||
use http::{Method, StatusCode, Uri};
|
||||
use hyper::{
|
||||
|
@ -9,9 +12,10 @@ use std::{
|
|||
io,
|
||||
path::PathBuf,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tokio::net::UnixListener;
|
||||
use tokio::net::{unix::UCred, UnixListener};
|
||||
use tokio::{
|
||||
io::{AsyncRead, AsyncWrite},
|
||||
net::UnixStream,
|
||||
|
@ -35,10 +39,10 @@ async fn main() {
|
|||
|
||||
let uds = UnixListener::bind(path.clone()).unwrap();
|
||||
tokio::spawn(async {
|
||||
let app = route("/", get(|| async { "Hello, World!" }));
|
||||
let app = route("/", get(handler));
|
||||
|
||||
hyper::Server::builder(ServerAccept { uds })
|
||||
.serve(app.into_make_service())
|
||||
.serve(app.into_make_service_with_connect_info::<UdsConnectInfo, _>())
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
@ -67,6 +71,12 @@ async fn main() {
|
|||
assert_eq!(body, "Hello, World!");
|
||||
}
|
||||
|
||||
async fn handler(ConnectInfo(info): ConnectInfo<UdsConnectInfo>) -> &'static str {
|
||||
println!("new connection from `{:?}`", info);
|
||||
|
||||
"Hello, World!"
|
||||
}
|
||||
|
||||
struct ServerAccept {
|
||||
uds: UnixListener,
|
||||
}
|
||||
|
@ -124,3 +134,23 @@ impl Connection for ClientConnection {
|
|||
Connected::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct UdsConnectInfo {
|
||||
peer_addr: Arc<tokio::net::unix::SocketAddr>,
|
||||
peer_cred: UCred,
|
||||
}
|
||||
|
||||
impl connect_info::Connected<&UnixStream> for UdsConnectInfo {
|
||||
type ConnectInfo = Self;
|
||||
|
||||
fn connect_info(target: &UnixStream) -> Self::ConnectInfo {
|
||||
let peer_addr = target.peer_addr().unwrap();
|
||||
let peer_cred = target.peer_cred().unwrap();
|
||||
|
||||
Self {
|
||||
peer_addr: Arc::new(peer_addr),
|
||||
peer_cred,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,203 @@
|
|||
//! Extractor for getting connection information from a client.
|
||||
//!
|
||||
//! See [`RoutingDsl::into_make_service_with_connect_info`] for more details.
|
||||
//!
|
||||
//! [`RoutingDsl::into_make_service_with_connect_info`]: crate::routing::RoutingDsl::into_make_service_with_connect_info
|
||||
|
||||
use super::{Extension, FromRequest, RequestParts};
|
||||
use async_trait::async_trait;
|
||||
use hyper::server::conn::AddrStream;
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
fmt,
|
||||
marker::PhantomData,
|
||||
net::SocketAddr,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tower::Service;
|
||||
use tower_http::add_extension::AddExtension;
|
||||
|
||||
/// A [`MakeService`] created from a router.
|
||||
///
|
||||
/// See [`RoutingDsl::into_make_service_with_connect_info`] for more details.
|
||||
///
|
||||
/// [`MakeService`]: tower::make::MakeService
|
||||
/// [`RoutingDsl::into_make_service_with_connect_info`]: crate::routing::RoutingDsl::into_make_service_with_connect_info
|
||||
pub struct IntoMakeServiceWithConnectInfo<S, C> {
|
||||
svc: S,
|
||||
_connect_info: PhantomData<fn() -> C>,
|
||||
}
|
||||
|
||||
impl<S, C> IntoMakeServiceWithConnectInfo<S, C> {
|
||||
pub(crate) fn new(svc: S) -> Self {
|
||||
Self {
|
||||
svc,
|
||||
_connect_info: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, C> fmt::Debug for IntoMakeServiceWithConnectInfo<S, C>
|
||||
where
|
||||
S: fmt::Debug,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("IntoMakeServiceWithConnectInfo")
|
||||
.field("svc", &self.svc)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait that connected IO resources implement and use to produce information
|
||||
/// about the connection.
|
||||
///
|
||||
/// The goal for this trait is to allow users to implement custom IO types that
|
||||
/// can still provide the same connection metadata.
|
||||
///
|
||||
/// See [`RoutingDsl::into_make_service_with_connect_info`] for more details.
|
||||
///
|
||||
/// [`RoutingDsl::into_make_service_with_connect_info`]: crate::routing::RoutingDsl::into_make_service_with_connect_info
|
||||
pub trait Connected<T> {
|
||||
/// The connection information type the IO resources generates.
|
||||
type ConnectInfo: Clone + Send + Sync + 'static;
|
||||
|
||||
/// Create type holding information about the connection.
|
||||
fn connect_info(target: T) -> Self::ConnectInfo;
|
||||
}
|
||||
|
||||
impl Connected<&AddrStream> for SocketAddr {
|
||||
type ConnectInfo = SocketAddr;
|
||||
|
||||
fn connect_info(target: &AddrStream) -> Self::ConnectInfo {
|
||||
target.remote_addr()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, C, T> Service<T> for IntoMakeServiceWithConnectInfo<S, C>
|
||||
where
|
||||
S: Clone,
|
||||
C: Connected<T>,
|
||||
{
|
||||
type Response = AddExtension<S, ConnectInfo<C::ConnectInfo>>;
|
||||
type Error = Infallible;
|
||||
type Future = ResponseFuture<Self::Response>;
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, target: T) -> Self::Future {
|
||||
let connect_info = ConnectInfo(C::connect_info(target));
|
||||
let svc = AddExtension::new(self.svc.clone(), connect_info);
|
||||
ResponseFuture(futures_util::future::ok(svc))
|
||||
}
|
||||
}
|
||||
|
||||
opaque_future! {
|
||||
/// Response future for [`IntoMakeServiceWithConnectInfo`].
|
||||
pub type ResponseFuture<T> =
|
||||
futures_util::future::Ready<Result<T, Infallible>>;
|
||||
}
|
||||
|
||||
/// Extractor for getting connection information produced by a [`Connected`].
|
||||
///
|
||||
/// Note this extractor requires you to use
|
||||
/// [`RoutingDsl::into_make_service_with_connect_info`] to run your app
|
||||
/// otherwise it will fail at runtime.
|
||||
///
|
||||
/// See [`RoutingDsl::into_make_service_with_connect_info`] for more details.
|
||||
///
|
||||
/// [`RoutingDsl::into_make_service_with_connect_info`]: crate::routing::RoutingDsl::into_make_service_with_connect_info
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct ConnectInfo<T>(pub T);
|
||||
|
||||
#[async_trait]
|
||||
impl<B, T> FromRequest<B> for ConnectInfo<T>
|
||||
where
|
||||
B: Send,
|
||||
T: Clone + Send + Sync + 'static,
|
||||
{
|
||||
type Rejection = <Extension<Self> as FromRequest<B>>::Rejection;
|
||||
|
||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||
let Extension(connect_info) = Extension::<Self>::from_request(req).await?;
|
||||
Ok(connect_info)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::prelude::*;
|
||||
use hyper::Server;
|
||||
use std::net::{SocketAddr, TcpListener};
|
||||
|
||||
#[tokio::test]
|
||||
async fn socket_addr() {
|
||||
async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
|
||||
format!("{}", addr)
|
||||
}
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
tokio::spawn(async move {
|
||||
let app = route("/", get(handler));
|
||||
let server = Server::from_tcp(listener)
|
||||
.unwrap()
|
||||
.serve(app.into_make_service_with_connect_info::<SocketAddr, _>());
|
||||
tx.send(()).unwrap();
|
||||
server.await.expect("server error");
|
||||
});
|
||||
rx.await.unwrap();
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let res = client.get(format!("http://{}", addr)).send().await.unwrap();
|
||||
let body = res.text().await.unwrap();
|
||||
assert!(body.starts_with("127.0.0.1:"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn custom() {
|
||||
#[derive(Clone, Debug)]
|
||||
struct MyConnectInfo {
|
||||
value: &'static str,
|
||||
}
|
||||
|
||||
impl Connected<&AddrStream> for MyConnectInfo {
|
||||
type ConnectInfo = Self;
|
||||
|
||||
fn connect_info(_target: &AddrStream) -> Self::ConnectInfo {
|
||||
Self {
|
||||
value: "it worked!",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handler(ConnectInfo(addr): ConnectInfo<MyConnectInfo>) -> &'static str {
|
||||
addr.value
|
||||
}
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
tokio::spawn(async move {
|
||||
let app = route("/", get(handler));
|
||||
let server = Server::from_tcp(listener)
|
||||
.unwrap()
|
||||
.serve(app.into_make_service_with_connect_info::<MyConnectInfo, _>());
|
||||
tx.send(()).unwrap();
|
||||
server.await.expect("server error");
|
||||
});
|
||||
rx.await.unwrap();
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let res = client.get(format!("http://{}", addr)).send().await.unwrap();
|
||||
let body = res.text().await.unwrap();
|
||||
assert_eq!(body, "it worked!");
|
||||
}
|
||||
}
|
|
@ -260,12 +260,16 @@ use std::{
|
|||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
pub mod connect_info;
|
||||
pub mod extractor_middleware;
|
||||
pub mod rejection;
|
||||
|
||||
#[doc(inline)]
|
||||
pub use self::extractor_middleware::extractor_middleware;
|
||||
|
||||
#[doc(inline)]
|
||||
pub use self::connect_info::ConnectInfo;
|
||||
|
||||
#[cfg(feature = "multipart")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
|
||||
pub mod multipart;
|
||||
|
@ -904,6 +908,8 @@ where
|
|||
/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
/// # };
|
||||
/// ```
|
||||
///
|
||||
/// [`Stream`]: https://docs.rs/futures/latest/futures/stream/trait.Stream.html
|
||||
#[derive(Debug)]
|
||||
pub struct BodyStream<B = crate::body::Body>(B);
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
use crate::{
|
||||
body::{box_body, BoxBody},
|
||||
buffer::MpscBuffer,
|
||||
extract::connect_info::{Connected, IntoMakeServiceWithConnectInfo},
|
||||
response::IntoResponse,
|
||||
util::ByteStr,
|
||||
};
|
||||
|
@ -266,6 +267,93 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized {
|
|||
{
|
||||
tower::make::Shared::new(self)
|
||||
}
|
||||
|
||||
/// Convert this router into a [`MakeService`], that will store `C`'s
|
||||
/// associated `ConnectInfo` in a request extension such that [`ConnectInfo`]
|
||||
/// can extract it.
|
||||
///
|
||||
/// This enables extracting things like the client's remote address.
|
||||
///
|
||||
/// Extracting [`std::net::SocketAddr`] is supported out of the box:
|
||||
///
|
||||
/// ```
|
||||
/// use axum::{prelude::*, extract::ConnectInfo};
|
||||
/// use std::net::SocketAddr;
|
||||
///
|
||||
/// let app = route("/", get(handler));
|
||||
///
|
||||
/// async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
|
||||
/// format!("Hello {}", addr)
|
||||
/// }
|
||||
///
|
||||
/// # async {
|
||||
/// hyper::Server::bind(&"0.0.0.0:3000".parse().unwrap())
|
||||
/// .serve(
|
||||
/// app.into_make_service_with_connect_info::<SocketAddr, _>()
|
||||
/// )
|
||||
/// .await
|
||||
/// .expect("server failed");
|
||||
/// # };
|
||||
/// ```
|
||||
///
|
||||
/// You can implement custom a [`Connected`] like so:
|
||||
///
|
||||
/// ```
|
||||
/// use axum::{
|
||||
/// prelude::*,
|
||||
/// extract::connect_info::{ConnectInfo, Connected},
|
||||
/// };
|
||||
/// use hyper::server::conn::AddrStream;
|
||||
///
|
||||
/// let app = route("/", get(handler));
|
||||
///
|
||||
/// async fn handler(
|
||||
/// ConnectInfo(my_connect_info): ConnectInfo<MyConnectInfo>,
|
||||
/// ) -> String {
|
||||
/// format!("Hello {:?}", my_connect_info)
|
||||
/// }
|
||||
///
|
||||
/// #[derive(Clone, Debug)]
|
||||
/// struct MyConnectInfo {
|
||||
/// // ...
|
||||
/// }
|
||||
///
|
||||
/// impl Connected<&AddrStream> for MyConnectInfo {
|
||||
/// type ConnectInfo = MyConnectInfo;
|
||||
///
|
||||
/// fn connect_info(target: &AddrStream) -> Self::ConnectInfo {
|
||||
/// MyConnectInfo {
|
||||
/// // ...
|
||||
/// }
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// # async {
|
||||
/// hyper::Server::bind(&"0.0.0.0:3000".parse().unwrap())
|
||||
/// .serve(
|
||||
/// app.into_make_service_with_connect_info::<MyConnectInfo, _>()
|
||||
/// )
|
||||
/// .await
|
||||
/// .expect("server failed");
|
||||
/// # };
|
||||
/// ```
|
||||
///
|
||||
/// See the [unix domain socket example][uds] for an example of how to use
|
||||
/// this to collect UDS connection info.
|
||||
///
|
||||
/// [`MakeService`]: tower::make::MakeService
|
||||
/// [`Connected`]: crate::extract::connect_info::Connected
|
||||
/// [`ConnectInfo`]: crate::extract::connect_info::ConnectInfo
|
||||
/// [uds]: https://github.com/tokio-rs/axum/blob/main/examples/unix_domain_socket.rs
|
||||
fn into_make_service_with_connect_info<C, Target>(
|
||||
self,
|
||||
) -> IntoMakeServiceWithConnectInfo<Self, C>
|
||||
where
|
||||
Self: Clone,
|
||||
C: Connected<Target>,
|
||||
{
|
||||
IntoMakeServiceWithConnectInfo::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, F> RoutingDsl for Route<S, F> {}
|
||||
|
|
Loading…
Reference in New Issue