Update middlewares

This commit is contained in:
photino 2023-01-05 00:24:55 +08:00
parent 38ec03b8f5
commit ff6f05aed6
15 changed files with 271 additions and 80 deletions

View File

@ -1,2 +1,2 @@
# zino
A minimal MVC framework
A minimal web framework.

View File

@ -1,6 +1,7 @@
[package]
name = "axum-app"
version = "0.2.1"
rust-version = "1.68"
edition = "2021"
publish = false

View File

@ -25,3 +25,6 @@ port = 5432
user = "postgres"
password = "postgres"
database = "data_cube"
[tracing]
filter = "warn,zino=info,zino_core=info"

View File

@ -2,8 +2,11 @@
name = "zino-core"
description = "Core types and traits for zino."
version = "0.2.1"
rust-version = "1.68"
edition = "2021"
license = "MIT"
categories = ["asynchronous", "network-programming", "web-programming::http-server"]
keywords = ["http", "web", "framework"]
homepage = "https://github.com/photino/zino"
repository = "https://github.com/photino/zino"
documentation = "https://docs.rs/zino-core"
@ -24,8 +27,9 @@ serde = { version = "1.0.152", features = ["derive"] }
serde_json = { version = "1.0.91" }
sha-1 = { version = "0.10.1" }
sha2 = { version = "0.10.6" }
sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres", "uuid", "time", "json"] }
sqlx = { version = "0.6.2", features = ["runtime-tokio-native-tls", "postgres", "uuid", "time", "json"] }
time = { version = "0.3.17", features = ["local-offset", "parsing", "serde"] }
tracing = { version = "0.1.37" }
toml = { version = "0.5.10" }
url = { version = "2.3.1" }
uuid = { version = "1.2.2", features = ["serde", "v4"] }

View File

@ -254,7 +254,7 @@ impl<S: ResponseCode> Response<S> {
) {
match Metric::new(name.into(), dur.into(), desc.into()) {
Ok(entry) => self.server_timing.push(entry),
Err(err) => eprintln!("{err}"),
Err(err) => tracing::error!("{err}"),
}
}
@ -336,14 +336,14 @@ impl From<Response<http::StatusCode>> for http::Response<Full<Bytes>> {
fn from(mut response: Response<http::StatusCode>) -> Self {
let mut res = match response.content_type {
Some(ref content_type) => match serde_json::to_vec(&response.data) {
Ok(bytes) => {
let mut res = http::Response::new(Full::from(bytes));
res.headers_mut().insert(
Ok(bytes) => http::Response::builder()
.status(response.status_code)
.header(
header::CONTENT_TYPE,
HeaderValue::from_str(content_type.as_str()).unwrap(),
);
res
}
)
.body(Full::from(bytes))
.unwrap(),
Err(err) => http::Response::builder()
.status(http::StatusCode::INTERNAL_SERVER_ERROR)
.header(header::CONTENT_TYPE, "text/plain")
@ -352,15 +352,16 @@ impl From<Response<http::StatusCode>> for http::Response<Full<Bytes>> {
},
None => match serde_json::to_vec(&response) {
Ok(bytes) => {
let mut res = http::Response::new(Full::from(bytes));
let content_type = if response.is_success() {
"application/json"
} else {
"application/problem+json"
};
res.headers_mut()
.insert(header::CONTENT_TYPE, HeaderValue::from_static(content_type));
res
http::Response::builder()
.status(response.status_code)
.header(header::CONTENT_TYPE, HeaderValue::from_static(content_type))
.body(Full::from(bytes))
.unwrap()
}
Err(err) => http::Response::builder()
.status(http::StatusCode::INTERNAL_SERVER_ERROR)

View File

@ -45,7 +45,7 @@ impl State {
.expect("fail to parse toml value");
match config {
Value::Table(table) => self.config = table,
_ => eprintln!("toml config file should be a table"),
_ => panic!("toml config file should be a table"),
}
}
@ -142,6 +142,7 @@ pub(crate) static SHARED_STATE: LazyLock<State> = LazyLock::new(|| {
app_env = arg.strip_prefix("--env=").unwrap().to_string();
}
}
let mut state = State::new(app_env);
state.load_config();
@ -160,7 +161,7 @@ pub(crate) static SHARED_STATE: LazyLock<State> = LazyLock::new(|| {
.expect("the `postgres` field should be a table");
match ConnectionPool::connect_lazy(postgres) {
Ok(pool) => pools.push(pool),
Err(err) => eprintln!("{err}"),
Err(err) => tracing::error!("{err}"),
}
}
}

View File

@ -2,6 +2,7 @@
name = "zino-derive"
description = "Derived traits for zino."
version = "0.2.1"
rust-version = "1.68"
edition = "2021"
license = "MIT"
homepage = "https://github.com/photino/zino"

View File

@ -2,6 +2,7 @@
name = "zino-model"
description = "Model types for zino."
version = "0.2.1"
rust-version = "1.68"
edition = "2021"
license = "MIT"
homepage = "https://github.com/photino/zino"

View File

@ -2,8 +2,11 @@
name = "zino"
description = "A minimal web framework."
version = "0.2.1"
rust-version = "1.68"
edition = "2021"
license = "MIT"
categories = ["asynchronous", "network-programming", "web-programming::http-server"]
keywords = ["http", "web", "framework"]
homepage = "https://github.com/photino/zino"
repository = "https://github.com/photino/zino"
documentation = "https://docs.rs/zino"
@ -27,7 +30,9 @@ tokio = { version = "1.23.0", features = ["rt-multi-thread", "sync"], optional =
tokio-stream = { version = "0.1.11", features = ["sync"], optional = true }
toml = { version = "0.5.10" }
tower = { version = "0.4.13", optional = true }
tower-http = { version = "0.1.1", features = ["add-extension", "fs"], optional = true }
tower-http = { version = "0.3.5", features = ["full"], optional = true }
tracing = { version = "0.1.37" }
tracing-subscriber = { version = "0.3.16", features = ["env-filter", "json", "local-time"] }
[dependencies.zino-core]
path = "../zino-core"

View File

@ -1,15 +1,25 @@
use axum::{
body::{Bytes, Full},
http::{self, StatusCode},
routing, Router, Server,
middleware, routing, Router, Server,
};
use futures::future;
use std::{
collections::HashMap, convert::Infallible, env, io, net::SocketAddr, path::Path, time::Instant,
collections::HashMap,
convert::Infallible,
env, io,
net::SocketAddr,
path::Path,
sync::{Arc, LazyLock},
time::Instant,
};
use tokio::runtime::Builder;
use tower::layer;
use tower_http::services::{ServeDir, ServeFile};
use tower::ServiceBuilder;
use tower_http::{
add_extension::AddExtensionLayer,
compression::CompressionLayer,
services::{ServeDir, ServeFile},
};
use zino_core::{Application, Response, State};
/// An HTTP server cluster for `axum`.
@ -76,7 +86,11 @@ impl Application for AxumCluster {
.build()?
.block_on(async {
let routes = self.routes;
let listeners = State::shared().listeners();
let shared_state = State::shared();
let app_env = shared_state.env();
tracing::info!("load config.{app_env}.toml");
let listeners = shared_state.listeners();
let servers = listeners.iter().map(|listener| {
let mut app = Router::new()
.route_service("/", serve_file_service.clone())
@ -89,25 +103,38 @@ impl Application for AxumCluster {
for (path, route) in &routes {
app = app.nest(path, route.clone());
}
let state = Arc::new(State::default());
app = app
.fallback_service(tower::service_fn(|_| async {
let res = Response::new(StatusCode::NOT_FOUND);
Ok::<http::Response<Full<Bytes>>, Infallible>(res.into())
}))
.layer(layer::layer_fn(
crate::middleware::axum_context::ContextMiddleware::new,
));
.layer(
ServiceBuilder::new()
.layer(LazyLock::force(
&crate::middleware::tower_tracing::TRACING_MIDDLEWARE,
))
.layer(LazyLock::force(
&crate::middleware::tower_cors::CORS_MIDDLEWARE,
))
.layer(middleware::from_fn(
crate::middleware::axum_context::request_context,
))
.layer(AddExtensionLayer::new(state))
.layer(CompressionLayer::new()),
);
let addr = listener
.parse()
.inspect(|addr| println!("listen on {addr}"))
.inspect(|addr| tracing::info!(env = app_env, "listen on {addr}"))
.unwrap_or_else(|_| panic!("invalid socket address: {listener}"));
Server::bind(&addr)
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
});
for result in future::join_all(servers).await {
if let Err(err) = result {
eprintln!("server error: {err}");
tracing::error!("server error: {err}");
}
}
});

View File

@ -29,13 +29,13 @@ pub(crate) async fn websocket_handler(
if topic.filter(|&t| t != event_topic).is_none() {
let message = Message::Text(data.to_string());
if let Err(err) = socket.send(message).await {
eprintln!("{err}");
tracing::error!("{err}");
}
}
}
}
}
Err(err) => eprintln!("{err}"),
Err(err) => tracing::error!("{err}"),
}
}
})

View File

@ -1,60 +1,29 @@
use crate::AxumExtractor;
use axum::{
body::Body,
http::{Request, Response},
body::{Body, BoxBody},
http::{Request, Response, StatusCode},
middleware::Next,
};
use futures::future::BoxFuture;
use std::{mem, task};
use tower::Service;
use zino_core::RequestContext;
/// Request context middleware.
#[derive(Debug, Clone)]
pub(crate) struct ContextMiddleware<S> {
inner: S,
}
pub(crate) async fn request_context(
req: Request<Body>,
next: Next<Body>,
) -> Result<Response<BoxBody>, StatusCode> {
let mut req_extractor = AxumExtractor(req);
let ext = match req_extractor.get_context() {
Some(_) => None,
None => {
let mut ctx = req_extractor.new_context();
let original_uri = req_extractor.original_uri().await;
ctx.set_request_path(original_uri.path());
Some(ctx)
}
};
impl<S> ContextMiddleware<S> {
/// Creates a new instance.
#[inline]
pub(crate) fn new(inner: S) -> Self {
Self { inner }
}
}
impl<S, ResBody> Service<Request<Body>> for ContextMiddleware<S>
where
S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
let cloned_inner = self.inner.clone();
let mut inner = mem::replace(&mut self.inner, cloned_inner);
Box::pin(async move {
let mut req_extractor = AxumExtractor(req);
let ext = match req_extractor.get_context() {
Some(_) => None,
None => {
let mut ctx = req_extractor.new_context();
let original_uri = req_extractor.original_uri().await;
ctx.set_request_path(original_uri.path());
Some(ctx)
}
};
let mut req = req_extractor.0;
if let Some(ctx) = ext {
req.extensions_mut().insert(ctx);
}
inner.call(req).await
})
let mut req = req_extractor.0;
if let Some(ctx) = ext {
req.extensions_mut().insert(ctx);
}
Ok(next.run(req).await)
}

View File

@ -1,2 +1,8 @@
#[cfg(feature = "axum-server")]
pub(crate) mod axum_context;
#[cfg(feature = "axum-server")]
pub(crate) mod tower_cors;
#[cfg(feature = "axum-server")]
pub(crate) mod tower_tracing;

View File

@ -0,0 +1,72 @@
use std::{sync::LazyLock, time::Duration};
use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin, CorsLayer, ExposeHeaders};
use zino_core::State;
// CORS middleware.
pub(crate) static CORS_MIDDLEWARE: LazyLock<CorsLayer> = LazyLock::new(|| {
let config = State::shared().config();
match config.get("cors").and_then(|t| t.as_table()) {
Some(cors) => {
let allow_credentials = cors
.get("allow-credentials")
.and_then(|t| t.as_bool())
.unwrap_or(false);
let allow_origin = cors
.get("allow-origin")
.and_then(|t| t.as_array())
.map(|v| {
let origins = v
.iter()
.filter_map(|t| t.as_str().and_then(|s| s.parse().ok()))
.collect::<Vec<_>>();
AllowOrigin::list(origins)
})
.unwrap_or(AllowOrigin::mirror_request());
let allow_methods = cors
.get("allow-methods")
.and_then(|t| t.as_array())
.map(|v| {
let methods = v
.iter()
.filter_map(|t| t.as_str().and_then(|s| s.parse().ok()))
.collect::<Vec<_>>();
AllowMethods::list(methods)
})
.unwrap_or(AllowMethods::mirror_request());
let allow_headers = cors
.get("allow-headers")
.and_then(|t| t.as_array())
.map(|v| {
let header_names = v
.iter()
.filter_map(|t| t.as_str().and_then(|s| s.parse().ok()))
.collect::<Vec<_>>();
AllowHeaders::list(header_names)
})
.unwrap_or(AllowHeaders::mirror_request());
let expose_headers = cors
.get("expose-headers")
.and_then(|t| t.as_array())
.map(|v| {
let header_names = v
.iter()
.filter_map(|t| t.as_str().and_then(|s| s.parse().ok()))
.collect::<Vec<_>>();
ExposeHeaders::list(header_names)
})
.unwrap_or(ExposeHeaders::any());
let max_age = cors
.get("max-age")
.and_then(|t| t.as_integer().and_then(|i| i.try_into().ok()))
.unwrap_or(86400);
CorsLayer::new()
.allow_credentials(allow_credentials)
.allow_origin(allow_origin)
.allow_methods(allow_methods)
.allow_headers(allow_headers)
.expose_headers(expose_headers)
.max_age(Duration::from_secs(max_age))
}
None => CorsLayer::permissive(),
}
});

View File

@ -0,0 +1,100 @@
use std::sync::LazyLock;
use tower_http::{
classify::{SharedClassifier, StatusInRangeAsFailures},
trace::{
DefaultMakeSpan, DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, DefaultOnRequest,
DefaultOnResponse, TraceLayer,
},
LatencyUnit,
};
use tracing::Level;
use tracing_subscriber::fmt::{time, writer::MakeWriterExt};
use zino_core::State;
// Tracing middleware.
pub(crate) static TRACING_MIDDLEWARE: LazyLock<
TraceLayer<SharedClassifier<StatusInRangeAsFailures>>,
> = LazyLock::new(|| {
let shared_state = State::shared();
let app_env = shared_state.env();
let is_dev = app_env == "dev";
let mut env_filter = if is_dev {
"sqlx=trace,tower_http=trace,zino=trace,zino_core=trace"
} else {
"warn,tower_http=info,zino=info,zino_core=info"
};
let mut display_target = is_dev;
let mut display_filename = false;
let mut display_line_number = false;
let mut display_thread_names = false;
let mut display_span_list = false;
let display_current_span = true;
let include_headers = true;
let config = shared_state.config();
if let Some(tracing) = config.get("tracing").and_then(|t| t.as_table()) {
if let Some(filter) = tracing.get("filter").and_then(|t| t.as_str()) {
env_filter = filter;
}
display_target = tracing
.get("display-target")
.and_then(|t| t.as_bool())
.unwrap_or(is_dev);
display_filename = tracing
.get("display-filename")
.and_then(|t| t.as_bool())
.unwrap_or(false);
display_line_number = tracing
.get("display-line-number")
.and_then(|t| t.as_bool())
.unwrap_or(false);
display_thread_names = tracing
.get("display-thread-names")
.and_then(|t| t.as_bool())
.unwrap_or(false);
display_span_list = tracing
.get("display-span-list")
.and_then(|t| t.as_bool())
.unwrap_or(false);
}
let stderr = std::io::stderr.with_max_level(Level::WARN);
tracing_subscriber::fmt()
.json()
.with_env_filter(env_filter)
.with_target(display_target)
.with_file(display_filename)
.with_line_number(display_line_number)
.with_thread_names(display_thread_names)
.with_span_list(display_span_list)
.with_current_span(display_current_span)
.with_timer(time::LocalTime::rfc_3339())
.map_writer(move |w| stderr.or_else(w))
.init();
let classifier = StatusInRangeAsFailures::new_for_client_and_server_errors();
TraceLayer::new(classifier.into_make_classifier())
.make_span_with(
DefaultMakeSpan::new()
.level(Level::INFO)
.include_headers(include_headers),
)
.on_request(DefaultOnRequest::new().level(Level::DEBUG))
.on_response(
DefaultOnResponse::new()
.level(Level::INFO)
.latency_unit(LatencyUnit::Micros),
)
.on_body_chunk(DefaultOnBodyChunk::new())
.on_eos(
DefaultOnEos::new()
.level(Level::INFO)
.latency_unit(LatencyUnit::Micros),
)
.on_failure(
DefaultOnFailure::new()
.level(Level::ERROR)
.latency_unit(LatencyUnit::Micros),
)
});