Add extractor for remote connection info (#55)

Fixes https://github.com/tokio-rs/axum/issues/43

With this you can get the remote address like so:

```rust
use axum::{prelude::*, extract::ConnectInfo};
use std::net::SocketAddr;

let app = route("/", get(handler));

async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
    format!("Hello {}", addr)
}

// Starting the app with `into_make_service_with_connect_info` is required
// for `ConnectInfo` to work.
let make_svc = app.into_make_service_with_connect_info::<SocketAddr, _>();

hyper::Server::bind(&"0.0.0.0:3000".parse().unwrap())
    .serve(make_svc)
    .await
    .expect("server failed");
```

This API is fully generic and supports whatever transport layer you're using with Hyper. I've updated the unix domain socket example to extract `peer_creds` and `peer_addr`.
This commit is contained in:
David Pedersen 2021-07-31 21:36:30 +02:00 committed by GitHub
parent 407aa533d7
commit f67abd1ee2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 333 additions and 5 deletions

View File

@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Implement `Sink` for `WebSocket` ([#52](https://github.com/tokio-rs/axum/pull/52))
- Implement `Deref` most extractors ([#56](https://github.com/tokio-rs/axum/pull/56))
- Return `405 Method Not Allowed` for unsupported method for route ([#63](https://github.com/tokio-rs/axum/pull/63))
- Add extractor for remote connection info ([#55](https://github.com/tokio-rs/axum/pull/55))
## Breaking changes

View File

@ -23,7 +23,7 @@ bytes = "1.0"
futures-util = "0.3"
http = "0.2"
http-body = "0.4"
hyper = { version = "0.14", features = ["server", "tcp"] }
hyper = { version = "0.14", features = ["server", "tcp", "http1"] }
pin-project = "1.0"
regex = "1.5"
serde = "1.0"

View File

@ -1,4 +1,7 @@
use axum::prelude::*;
use axum::{
extract::connect_info::{self, ConnectInfo},
prelude::*,
};
use futures::ready;
use http::{Method, StatusCode, Uri};
use hyper::{
@ -9,9 +12,10 @@ use std::{
io,
path::PathBuf,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tokio::net::UnixListener;
use tokio::net::{unix::UCred, UnixListener};
use tokio::{
io::{AsyncRead, AsyncWrite},
net::UnixStream,
@ -35,10 +39,10 @@ async fn main() {
let uds = UnixListener::bind(path.clone()).unwrap();
tokio::spawn(async {
let app = route("/", get(|| async { "Hello, World!" }));
let app = route("/", get(handler));
hyper::Server::builder(ServerAccept { uds })
.serve(app.into_make_service())
.serve(app.into_make_service_with_connect_info::<UdsConnectInfo, _>())
.await
.unwrap();
});
@ -67,6 +71,12 @@ async fn main() {
assert_eq!(body, "Hello, World!");
}
async fn handler(ConnectInfo(info): ConnectInfo<UdsConnectInfo>) -> &'static str {
println!("new connection from `{:?}`", info);
"Hello, World!"
}
struct ServerAccept {
uds: UnixListener,
}
@ -124,3 +134,23 @@ impl Connection for ClientConnection {
Connected::new()
}
}
#[derive(Clone, Debug)]
struct UdsConnectInfo {
peer_addr: Arc<tokio::net::unix::SocketAddr>,
peer_cred: UCred,
}
impl connect_info::Connected<&UnixStream> for UdsConnectInfo {
type ConnectInfo = Self;
fn connect_info(target: &UnixStream) -> Self::ConnectInfo {
let peer_addr = target.peer_addr().unwrap();
let peer_cred = target.peer_cred().unwrap();
Self {
peer_addr: Arc::new(peer_addr),
peer_cred,
}
}
}

203
src/extract/connect_info.rs Normal file
View File

@ -0,0 +1,203 @@
//! Extractor for getting connection information from a client.
//!
//! See [`RoutingDsl::into_make_service_with_connect_info`] for more details.
//!
//! [`RoutingDsl::into_make_service_with_connect_info`]: crate::routing::RoutingDsl::into_make_service_with_connect_info
use super::{Extension, FromRequest, RequestParts};
use async_trait::async_trait;
use hyper::server::conn::AddrStream;
use std::{
convert::Infallible,
fmt,
marker::PhantomData,
net::SocketAddr,
task::{Context, Poll},
};
use tower::Service;
use tower_http::add_extension::AddExtension;
/// A [`MakeService`] created from a router.
///
/// See [`RoutingDsl::into_make_service_with_connect_info`] for more details.
///
/// [`MakeService`]: tower::make::MakeService
/// [`RoutingDsl::into_make_service_with_connect_info`]: crate::routing::RoutingDsl::into_make_service_with_connect_info
pub struct IntoMakeServiceWithConnectInfo<S, C> {
svc: S,
_connect_info: PhantomData<fn() -> C>,
}
impl<S, C> IntoMakeServiceWithConnectInfo<S, C> {
pub(crate) fn new(svc: S) -> Self {
Self {
svc,
_connect_info: PhantomData,
}
}
}
impl<S, C> fmt::Debug for IntoMakeServiceWithConnectInfo<S, C>
where
S: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("IntoMakeServiceWithConnectInfo")
.field("svc", &self.svc)
.finish()
}
}
/// Trait that connected IO resources implement and use to produce information
/// about the connection.
///
/// The goal for this trait is to allow users to implement custom IO types that
/// can still provide the same connection metadata.
///
/// See [`RoutingDsl::into_make_service_with_connect_info`] for more details.
///
/// [`RoutingDsl::into_make_service_with_connect_info`]: crate::routing::RoutingDsl::into_make_service_with_connect_info
pub trait Connected<T> {
/// The connection information type the IO resources generates.
type ConnectInfo: Clone + Send + Sync + 'static;
/// Create type holding information about the connection.
fn connect_info(target: T) -> Self::ConnectInfo;
}
impl Connected<&AddrStream> for SocketAddr {
type ConnectInfo = SocketAddr;
fn connect_info(target: &AddrStream) -> Self::ConnectInfo {
target.remote_addr()
}
}
impl<S, C, T> Service<T> for IntoMakeServiceWithConnectInfo<S, C>
where
S: Clone,
C: Connected<T>,
{
type Response = AddExtension<S, ConnectInfo<C::ConnectInfo>>;
type Error = Infallible;
type Future = ResponseFuture<Self::Response>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, target: T) -> Self::Future {
let connect_info = ConnectInfo(C::connect_info(target));
let svc = AddExtension::new(self.svc.clone(), connect_info);
ResponseFuture(futures_util::future::ok(svc))
}
}
opaque_future! {
/// Response future for [`IntoMakeServiceWithConnectInfo`].
pub type ResponseFuture<T> =
futures_util::future::Ready<Result<T, Infallible>>;
}
/// Extractor for getting connection information produced by a [`Connected`].
///
/// Note this extractor requires you to use
/// [`RoutingDsl::into_make_service_with_connect_info`] to run your app
/// otherwise it will fail at runtime.
///
/// See [`RoutingDsl::into_make_service_with_connect_info`] for more details.
///
/// [`RoutingDsl::into_make_service_with_connect_info`]: crate::routing::RoutingDsl::into_make_service_with_connect_info
#[derive(Clone, Copy, Debug)]
pub struct ConnectInfo<T>(pub T);
#[async_trait]
impl<B, T> FromRequest<B> for ConnectInfo<T>
where
B: Send,
T: Clone + Send + Sync + 'static,
{
type Rejection = <Extension<Self> as FromRequest<B>>::Rejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let Extension(connect_info) = Extension::<Self>::from_request(req).await?;
Ok(connect_info)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::prelude::*;
use hyper::Server;
use std::net::{SocketAddr, TcpListener};
#[tokio::test]
async fn socket_addr() {
async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
format!("{}", addr)
}
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let (tx, rx) = tokio::sync::oneshot::channel();
tokio::spawn(async move {
let app = route("/", get(handler));
let server = Server::from_tcp(listener)
.unwrap()
.serve(app.into_make_service_with_connect_info::<SocketAddr, _>());
tx.send(()).unwrap();
server.await.expect("server error");
});
rx.await.unwrap();
let client = reqwest::Client::new();
let res = client.get(format!("http://{}", addr)).send().await.unwrap();
let body = res.text().await.unwrap();
assert!(body.starts_with("127.0.0.1:"));
}
#[tokio::test]
async fn custom() {
#[derive(Clone, Debug)]
struct MyConnectInfo {
value: &'static str,
}
impl Connected<&AddrStream> for MyConnectInfo {
type ConnectInfo = Self;
fn connect_info(_target: &AddrStream) -> Self::ConnectInfo {
Self {
value: "it worked!",
}
}
}
async fn handler(ConnectInfo(addr): ConnectInfo<MyConnectInfo>) -> &'static str {
addr.value
}
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let (tx, rx) = tokio::sync::oneshot::channel();
tokio::spawn(async move {
let app = route("/", get(handler));
let server = Server::from_tcp(listener)
.unwrap()
.serve(app.into_make_service_with_connect_info::<MyConnectInfo, _>());
tx.send(()).unwrap();
server.await.expect("server error");
});
rx.await.unwrap();
let client = reqwest::Client::new();
let res = client.get(format!("http://{}", addr)).send().await.unwrap();
let body = res.text().await.unwrap();
assert_eq!(body, "it worked!");
}
}

View File

@ -260,12 +260,16 @@ use std::{
task::{Context, Poll},
};
pub mod connect_info;
pub mod extractor_middleware;
pub mod rejection;
#[doc(inline)]
pub use self::extractor_middleware::extractor_middleware;
#[doc(inline)]
pub use self::connect_info::ConnectInfo;
#[cfg(feature = "multipart")]
#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
pub mod multipart;
@ -904,6 +908,8 @@ where
/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// [`Stream`]: https://docs.rs/futures/latest/futures/stream/trait.Stream.html
#[derive(Debug)]
pub struct BodyStream<B = crate::body::Body>(B);

View File

@ -3,6 +3,7 @@
use crate::{
body::{box_body, BoxBody},
buffer::MpscBuffer,
extract::connect_info::{Connected, IntoMakeServiceWithConnectInfo},
response::IntoResponse,
util::ByteStr,
};
@ -266,6 +267,93 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized {
{
tower::make::Shared::new(self)
}
/// Convert this router into a [`MakeService`], that will store `C`'s
/// associated `ConnectInfo` in a request extension such that [`ConnectInfo`]
/// can extract it.
///
/// This enables extracting things like the client's remote address.
///
/// Extracting [`std::net::SocketAddr`] is supported out of the box:
///
/// ```
/// use axum::{prelude::*, extract::ConnectInfo};
/// use std::net::SocketAddr;
///
/// let app = route("/", get(handler));
///
/// async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
/// format!("Hello {}", addr)
/// }
///
/// # async {
/// hyper::Server::bind(&"0.0.0.0:3000".parse().unwrap())
/// .serve(
/// app.into_make_service_with_connect_info::<SocketAddr, _>()
/// )
/// .await
/// .expect("server failed");
/// # };
/// ```
///
/// You can implement custom a [`Connected`] like so:
///
/// ```
/// use axum::{
/// prelude::*,
/// extract::connect_info::{ConnectInfo, Connected},
/// };
/// use hyper::server::conn::AddrStream;
///
/// let app = route("/", get(handler));
///
/// async fn handler(
/// ConnectInfo(my_connect_info): ConnectInfo<MyConnectInfo>,
/// ) -> String {
/// format!("Hello {:?}", my_connect_info)
/// }
///
/// #[derive(Clone, Debug)]
/// struct MyConnectInfo {
/// // ...
/// }
///
/// impl Connected<&AddrStream> for MyConnectInfo {
/// type ConnectInfo = MyConnectInfo;
///
/// fn connect_info(target: &AddrStream) -> Self::ConnectInfo {
/// MyConnectInfo {
/// // ...
/// }
/// }
/// }
///
/// # async {
/// hyper::Server::bind(&"0.0.0.0:3000".parse().unwrap())
/// .serve(
/// app.into_make_service_with_connect_info::<MyConnectInfo, _>()
/// )
/// .await
/// .expect("server failed");
/// # };
/// ```
///
/// See the [unix domain socket example][uds] for an example of how to use
/// this to collect UDS connection info.
///
/// [`MakeService`]: tower::make::MakeService
/// [`Connected`]: crate::extract::connect_info::Connected
/// [`ConnectInfo`]: crate::extract::connect_info::ConnectInfo
/// [uds]: https://github.com/tokio-rs/axum/blob/main/examples/unix_domain_socket.rs
fn into_make_service_with_connect_info<C, Target>(
self,
) -> IntoMakeServiceWithConnectInfo<Self, C>
where
Self: Clone,
C: Connected<Target>,
{
IntoMakeServiceWithConnectInfo::new(self)
}
}
impl<S, F> RoutingDsl for Route<S, F> {}