From c9c507aece33d4df366977a58d8782a631b428be Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Sat, 12 Jun 2021 20:50:30 +0200 Subject: [PATCH] Add support for websockets (#3) Basically a copy/paste of whats in warp. Example usage: ```rust use tower_web::{prelude::*, ws::{ws, WebSocket}}; let app = route("/ws", ws(handle_socket)); async fn handle_socket(mut socket: WebSocket) { while let Some(msg) = socket.recv().await { let msg = msg.unwrap(); socket.send(msg).await.unwrap(); } } ``` --- Cargo.toml | 21 ++- examples/websocket.rs | 63 +++++++ examples/websocket/index.html | 1 + examples/websocket/script.js | 9 + src/lib.rs | 4 + src/response.rs | 11 ++ src/routing.rs | 6 +- src/ws/future.rs | 68 +++++++ src/ws/mod.rs | 337 ++++++++++++++++++++++++++++++++++ 9 files changed, 518 insertions(+), 2 deletions(-) create mode 100644 examples/websocket.rs create mode 100644 examples/websocket/index.html create mode 100644 examples/websocket/script.js create mode 100644 src/ws/future.rs create mode 100644 src/ws/mod.rs diff --git a/Cargo.toml b/Cargo.toml index b908b562..98561cf3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,9 @@ readme = "README.md" repository = "https://github.com/davidpdrsn/tower-web" version = "0.1.0" +[features] +ws = ["tokio-tungstenite", "sha-1", "base64"] + [dependencies] async-trait = "0.1" bytes = "1.0" @@ -28,16 +31,32 @@ tokio = { version = "1", features = ["time"] } tower = { version = "0.4", features = ["util", "buffer"] } tower-http = { version = "0.1", features = ["add-extension"] } +# optional dependencies +tokio-tungstenite = { optional = true, version = "0.14" } +sha-1 = { optional = true, version = "0.9.6" } +base64 = { optional = true, version = "0.13" } + [dev-dependencies] hyper = { version = "0.14", features = ["full"] } reqwest = { version = "0.11", features = ["json", "stream"] } serde = { version = "1.0", features = ["derive"] } tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread"] } -tower = { version = "0.4", features = ["util", "make", "timeout", "limit", "load-shed", "steer"] } tracing = "0.1" tracing-subscriber = "0.2" uuid = "0.8" +[dev-dependencies.tower] +version = "0.4" +features = [ + "util", + "make", + "timeout", + "limit", + "load-shed", + "steer", + "filter", +] + [dev-dependencies.tower-http] version = "0.1" features = [ diff --git a/examples/websocket.rs b/examples/websocket.rs new file mode 100644 index 00000000..f58ce0ad --- /dev/null +++ b/examples/websocket.rs @@ -0,0 +1,63 @@ +//! Example websocket server. +//! +//! Run with +//! +//! ``` +//! RUST_LOG=tower_http=debug,key_value_store=trace \ +//! cargo run \ +//! --features ws \ +//! --example websocket +//! ``` + +use http::StatusCode; +use hyper::Server; +use std::net::SocketAddr; +use tower::make::Shared; +use tower_http::{ + services::ServeDir, + trace::{DefaultMakeSpan, TraceLayer}, +}; +use tower_web::{ + prelude::*, + routing::nest, + service::ServiceExt, + ws::{ws, Message, WebSocket}, +}; + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + + // build our application with some routes + let app = nest( + "/", + ServeDir::new("examples/websocket") + .append_index_html_on_directories(true) + .handle_error(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string())), + ) + // routes are matched from bottom to top, so we have to put `nest` at the + // top since it matches all routes + .route("/ws", ws(handle_socket)) + // logging so we can see whats going on + .layer( + TraceLayer::new_for_http().make_span_with(DefaultMakeSpan::default().include_headers(true)), + ); + + // run it with hyper + let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + tracing::debug!("listening on {}", addr); + let server = Server::bind(&addr).serve(Shared::new(app)); + server.await.unwrap(); +} + +async fn handle_socket(mut socket: WebSocket) { + if let Some(msg) = socket.recv().await { + let msg = msg.unwrap(); + println!("Client says: {:?}", msg); + } + + loop { + socket.send(Message::text("Hi!")).await.unwrap(); + tokio::time::sleep(std::time::Duration::from_secs(3)).await; + } +} diff --git a/examples/websocket/index.html b/examples/websocket/index.html new file mode 100644 index 00000000..390bb86b --- /dev/null +++ b/examples/websocket/index.html @@ -0,0 +1 @@ + diff --git a/examples/websocket/script.js b/examples/websocket/script.js new file mode 100644 index 00000000..3f166736 --- /dev/null +++ b/examples/websocket/script.js @@ -0,0 +1,9 @@ +const socket = new WebSocket('ws://localhost:3000/ws'); + +socket.addEventListener('open', function (event) { + socket.send('Hello Server!'); +}); + +socket.addEventListener('message', function (event) { + console.log('Message from server ', event.data); +}); diff --git a/src/lib.rs b/src/lib.rs index ccf9c2fb..031f8e5c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -610,6 +610,10 @@ pub mod response; pub mod routing; pub mod service; +#[cfg(feature = "ws")] +#[cfg_attr(docsrs, doc(cfg(feature = "ws")))] +pub mod ws; + #[cfg(test)] mod tests; diff --git a/src/response.rs b/src/response.rs index b32504f6..dea64edd 100644 --- a/src/response.rs +++ b/src/response.rs @@ -147,6 +147,17 @@ where } } +impl IntoResponse for (HeaderMap, T) +where + T: Into, +{ + fn into_response(self) -> Response { + let mut res = Response::new(self.1.into()); + *res.headers_mut() = self.0; + res + } +} + impl IntoResponse for (StatusCode, HeaderMap, T) where T: Into, diff --git a/src/routing.rs b/src/routing.rs index 9d4c3e45..19579e45 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -770,12 +770,16 @@ where fn strip_prefix(uri: &Uri, prefix: &str) -> Uri { let path_and_query = if let Some(path_and_query) = uri.path_and_query() { - let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) { + let mut new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) { path } else { path_and_query.path() }; + if new_path.is_empty() { + new_path = "/"; + } + if let Some(query) = path_and_query.query() { Some( format!("{}?{}", new_path, query) diff --git a/src/ws/future.rs b/src/ws/future.rs new file mode 100644 index 00000000..9dd56223 --- /dev/null +++ b/src/ws/future.rs @@ -0,0 +1,68 @@ +//! Future types. + +use bytes::Bytes; +use http::{HeaderValue, Response, StatusCode}; +use http_body::Full; +use sha1::{Digest, Sha1}; +use std::{ + convert::Infallible, + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +/// Response future for [`WebSocketUpgrade`](super::WebSocketUpgrade). +#[derive(Debug)] +pub struct ResponseFuture(Result, Option<(StatusCode, &'static str)>>); + +impl ResponseFuture { + pub(super) fn ok(key: HeaderValue) -> Self { + Self(Ok(Some(key))) + } + + pub(super) fn err(status: StatusCode, body: &'static str) -> Self { + Self(Err(Some((status, body)))) + } +} + +impl Future for ResponseFuture { + type Output = Result>, Infallible>; + + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + let res = match self.get_mut().0.as_mut() { + Ok(key) => Response::builder() + .status(StatusCode::SWITCHING_PROTOCOLS) + .header( + http::header::CONNECTION, + HeaderValue::from_str("upgrade").unwrap(), + ) + .header( + http::header::UPGRADE, + HeaderValue::from_str("websocket").unwrap(), + ) + .header( + http::header::SEC_WEBSOCKET_ACCEPT, + sign(key.take().unwrap().as_bytes()), + ) + .body(Full::new(Bytes::new())) + .unwrap(), + Err(err) => { + let (status, body) = err.take().unwrap(); + Response::builder() + .status(status) + .body(Full::from(body)) + .unwrap() + } + }; + + Poll::Ready(Ok(res)) + } +} + +fn sign(key: &[u8]) -> HeaderValue { + let mut sha1 = Sha1::default(); + sha1.update(key); + sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]); + let b64 = Bytes::from(base64::encode(&sha1.finalize())); + HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value") +} diff --git a/src/ws/mod.rs b/src/ws/mod.rs new file mode 100644 index 00000000..32d685fb --- /dev/null +++ b/src/ws/mod.rs @@ -0,0 +1,337 @@ +//! Handle websocket connections. +//! +//! # Example +//! +//! ``` +//! use tower_web::{prelude::*, ws::{ws, WebSocket}}; +//! +//! let app = route("/ws", ws(handle_socket)); +//! +//! async fn handle_socket(mut socket: WebSocket) { +//! while let Some(msg) = socket.recv().await { +//! let msg = msg.unwrap(); +//! socket.send(msg).await.unwrap(); +//! } +//! } +//! ``` + +use crate::{routing::EmptyRouter, service::OnMethod}; +use bytes::Bytes; +use future::ResponseFuture; +use futures_util::{sink::SinkExt, stream::StreamExt}; +use http::{ + header::{self, HeaderName}, + HeaderValue, Request, Response, StatusCode, +}; +use http_body::Full; +use hyper::upgrade::{OnUpgrade, Upgraded}; +use std::{borrow::Cow, convert::Infallible, fmt, future::Future, task::Context, task::Poll}; +use tokio_tungstenite::{ + tungstenite::protocol::{self, WebSocketConfig}, + WebSocketStream, +}; +use tower::{BoxError, Service}; + +pub mod future; + +/// Create a new [`WebSocketUpgrade`] service that will call the closure with +/// each connection. +/// +/// See the [module docs](crate::ws) for more details. +pub fn ws(callback: F) -> OnMethod, EmptyRouter> +where + F: FnOnce(WebSocket) -> Fut + Clone + Send + 'static, + Fut: Future + Send + 'static, +{ + let svc = WebSocketUpgrade { + callback, + config: WebSocketConfig::default(), + }; + crate::service::get(svc) +} + +/// [`Service`] that ugprades connections to websockets and spawns a task to +/// handle the stream. +/// +/// Created with [`ws`]. +/// +/// See the [module docs](crate::ws) for more details. +#[derive(Clone)] +pub struct WebSocketUpgrade { + callback: F, + config: WebSocketConfig, +} + +impl fmt::Debug for WebSocketUpgrade { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WebSocketUpgrade") + .field("callback", &format_args!("{}", std::any::type_name::())) + .field("config", &self.config) + .finish() + } +} + +impl WebSocketUpgrade { + /// Set the size of the internal message send queue. + pub fn max_send_queue(mut self, max: usize) -> Self { + self.config.max_send_queue = Some(max); + self + } + + /// Set the maximum message size (defaults to 64 megabytes) + pub fn max_message_size(mut self, max: usize) -> Self { + self.config.max_message_size = Some(max); + self + } + + /// Set the maximum frame size (defaults to 16 megabytes) + pub fn max_frame_size(mut self, max: usize) -> Self { + self.config.max_frame_size = Some(max); + self + } +} + +impl Service> for WebSocketUpgrade +where + F: FnOnce(WebSocket) -> Fut + Clone + Send + 'static, + Fut: Future + Send + 'static, +{ + type Response = Response>; + type Error = Infallible; + type Future = ResponseFuture; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + if !header_eq( + &req, + header::CONNECTION, + HeaderValue::from_static("upgrade"), + ) { + return ResponseFuture::err( + StatusCode::BAD_REQUEST, + "Connection header did not include 'upgrade'", + ); + } + + if !header_eq(&req, header::UPGRADE, HeaderValue::from_static("websocket")) { + return ResponseFuture::err( + StatusCode::BAD_REQUEST, + "`Upgrade` header did not include 'websocket'", + ); + } + + if !header_eq( + &req, + header::SEC_WEBSOCKET_VERSION, + HeaderValue::from_static("13"), + ) { + return ResponseFuture::err( + StatusCode::BAD_REQUEST, + "`Sec-Websocket-Version` header did not include '13'", + ); + } + + let key = if let Some(key) = req.headers_mut().remove(header::SEC_WEBSOCKET_KEY) { + key + } else { + return ResponseFuture::err( + StatusCode::BAD_REQUEST, + "`Sec-Websocket-Key` header missing", + ); + }; + + let on_upgrade = req.extensions_mut().remove::().unwrap(); + + let config = self.config; + let callback = self.callback.clone(); + + tokio::spawn(async move { + let upgraded = on_upgrade.await.unwrap(); + let socket = + WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config)) + .await; + let socket = WebSocket { inner: socket }; + callback(socket).await; + }); + + ResponseFuture::ok(key) + } +} + +fn header_eq(req: &Request, key: HeaderName, value: HeaderValue) -> bool { + if let Some(header) = req.headers().get(&key) { + header.as_bytes().eq_ignore_ascii_case(value.as_bytes()) + } else { + false + } +} + +/// A stream of websocket messages. +#[derive(Debug)] +pub struct WebSocket { + inner: WebSocketStream, +} + +impl WebSocket { + /// Receive another message. + /// + /// Returns `None` is stream has closed. + pub async fn recv(&mut self) -> Option> { + self.inner + .next() + .await + .map(|result| result.map_err(Into::into).map(|inner| Message { inner })) + } + + /// Send a message. + pub async fn send(&mut self, msg: Message) -> Result<(), BoxError> { + self.inner.send(msg.inner).await.map_err(Into::into) + } + + /// Gracefully close this websocket. + pub async fn close(mut self) -> Result<(), BoxError> { + self.inner.close(None).await.map_err(Into::into) + } +} + +/// A WebSocket message. +#[derive(Eq, PartialEq, Clone)] +pub struct Message { + inner: protocol::Message, +} + +impl Message { + /// Construct a new Text `Message`. + pub fn text(s: S) -> Message + where + S: Into, + { + Message { + inner: protocol::Message::text(s), + } + } + + /// Construct a new Binary `Message`. + pub fn binary(v: V) -> Message + where + V: Into>, + { + Message { + inner: protocol::Message::binary(v), + } + } + + /// Construct a new Ping `Message`. + pub fn ping>>(v: V) -> Message { + Message { + inner: protocol::Message::Ping(v.into()), + } + } + + /// Construct a new Pong `Message`. + /// + /// Note that one rarely needs to manually construct a Pong message because + /// the underlying tungstenite socket automatically responds to the Ping + /// messages it receives. Manual construction might still be useful in some + /// cases like in tests or to send unidirectional heartbeats. + pub fn pong>>(v: V) -> Message { + Message { + inner: protocol::Message::Pong(v.into()), + } + } + + /// Construct the default Close `Message`. + pub fn close() -> Message { + Message { + inner: protocol::Message::Close(None), + } + } + + /// Construct a Close `Message` with a code and reason. + pub fn close_with(code: C, reason: R) -> Message + where + C: Into, + R: Into>, + { + Message { + inner: protocol::Message::Close(Some(protocol::frame::CloseFrame { + code: protocol::frame::coding::CloseCode::from(code.into()), + reason: reason.into(), + })), + } + } + + /// Returns true if this message is a Text message. + pub fn is_text(&self) -> bool { + self.inner.is_text() + } + + /// Returns true if this message is a Binary message. + pub fn is_binary(&self) -> bool { + self.inner.is_binary() + } + + /// Returns true if this message a is a Close message. + pub fn is_close(&self) -> bool { + self.inner.is_close() + } + + /// Returns true if this message is a Ping message. + pub fn is_ping(&self) -> bool { + self.inner.is_ping() + } + + /// Returns true if this message is a Pong message. + pub fn is_pong(&self) -> bool { + self.inner.is_pong() + } + + /// Try to get the close frame (close code and reason) + pub fn close_frame(&self) -> Option<(u16, &str)> { + if let protocol::Message::Close(Some(close_frame)) = &self.inner { + Some((close_frame.code.into(), close_frame.reason.as_ref())) + } else { + None + } + } + + /// Try to get a reference to the string text, if this is a Text message. + pub fn to_str(&self) -> Option<&str> { + if let protocol::Message::Text(s) = &self.inner { + Some(s) + } else { + None + } + } + + /// Return the bytes of this message, if the message can contain data. + pub fn as_bytes(&self) -> &[u8] { + match self.inner { + protocol::Message::Text(ref s) => s.as_bytes(), + protocol::Message::Binary(ref v) => v, + protocol::Message::Ping(ref v) => v, + protocol::Message::Pong(ref v) => v, + protocol::Message::Close(_) => &[], + } + } + + /// Destructure this message into binary data. + pub fn into_bytes(self) -> Vec { + self.inner.into_data() + } +} + +impl fmt::Debug for Message { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.inner.fmt(f) + } +} + +impl From for Vec { + fn from(msg: Message) -> Self { + msg.into_bytes() + } +}