Fix crypto

This commit is contained in:
photino 2023-01-06 04:55:37 +08:00
parent ff6f05aed6
commit af815b0ea7
22 changed files with 174 additions and 104 deletions

View File

@ -1,6 +1,6 @@
[package] [package]
name = "axum-app" name = "axum-app"
version = "0.2.1" version = "0.2.2"
rust-version = "1.68" rust-version = "1.68"
edition = "2021" edition = "2021"
publish = false publish = false

View File

@ -1,6 +1,6 @@
name = "data-cube" name = "data-cube"
version = "1.0.2" version = "0.2.2"
[main] [main]
host = "127.0.0.1" host = "127.0.0.1"
@ -22,6 +22,9 @@ namespace = "dc"
[[postgres]] [[postgres]]
host = "localhost" host = "localhost"
port = 5432 port = 5432
user = "postgres"
password = "postgres"
database = "data_cube" database = "data_cube"
user = "postgres"
password = "QAx01wnh1i5ER713zfHmZi6dIUYn/Iq9ag+iUGtvKzEFJFYW"
[tracing]
filter = "sqlx=trace,tower_http=trace,zino=trace,zino_core=trace"

View File

@ -1,6 +1,6 @@
name = "data-cube" name = "data-cube"
version = "1.0.2" version = "0.2.2"
[main] [main]
host = "127.0.0.1" host = "127.0.0.1"
@ -22,9 +22,9 @@ namespace = "dc"
[[postgres]] [[postgres]]
host = "localhost" host = "localhost"
port = 5432 port = 5432
user = "postgres"
password = "postgres"
database = "data_cube" database = "data_cube"
user = "postgres"
password = "G76hTg8T5Aa+SZQFc+0QnsRLo1UOjqpkp/jUQ+lySc8QCt4B"
[tracing] [tracing]
filter = "warn,zino=info,zino_core=info" filter = "sqlx=warn,tower_http=info,zino=info,zino_core=info"

View File

@ -1,7 +1,7 @@
[package] [package]
name = "zino-core" name = "zino-core"
description = "Core types and traits for zino." description = "Core types and traits for zino."
version = "0.2.1" version = "0.2.2"
rust-version = "1.68" rust-version = "1.68"
edition = "2021" edition = "2021"
license = "MIT" license = "MIT"
@ -25,7 +25,6 @@ http-types = { version = "2.12.0" }
rand = { version = "0.8.5" } rand = { version = "0.8.5" }
serde = { version = "1.0.152", features = ["derive"] } serde = { version = "1.0.152", features = ["derive"] }
serde_json = { version = "1.0.91" } serde_json = { version = "1.0.91" }
sha-1 = { version = "0.10.1" }
sha2 = { version = "0.10.6" } sha2 = { version = "0.10.6" }
sqlx = { version = "0.6.2", features = ["runtime-tokio-native-tls", "postgres", "uuid", "time", "json"] } sqlx = { version = "0.6.2", features = ["runtime-tokio-native-tls", "postgres", "uuid", "time", "json"] }
time = { version = "0.3.17", features = ["local-offset", "parsing", "serde"] } time = { version = "0.3.17", features = ["local-offset", "parsing", "serde"] }

View File

@ -1,6 +1,8 @@
use crate::{DateTime, Map, Validation}; use crate::{DateTime, Map, Validation};
use hmac::{Hmac, Mac}; use hmac::{
use sha1::Sha1; digest::{FixedOutput, KeyInit, MacMarker, Update},
Mac,
};
use std::time::Duration; use std::time::Duration;
mod access_key; mod access_key;
@ -11,7 +13,7 @@ pub use access_key::{AccessKeyId, SecretAccessKey};
pub(crate) use security_token::ParseTokenError; pub(crate) use security_token::ParseTokenError;
pub use security_token::SecurityToken; pub use security_token::SecurityToken;
/// HTTP signature using RFC 2104 HMAC-SHA1. /// HTTP signature using HMAC.
pub struct Authentication { pub struct Authentication {
/// Service name. /// Service name.
service_name: String, service_name: String,
@ -49,7 +51,7 @@ impl Authentication {
accept: None, accept: None,
content_md5: None, content_md5: None,
content_type: None, content_type: None,
date_header: ("Date".to_string(), DateTime::now()), date_header: ("date".to_string(), DateTime::now()),
expires: None, expires: None,
headers: Vec::new(), headers: Vec::new(),
resource: String::new(), resource: String::new(),
@ -74,25 +76,25 @@ impl Authentication {
self.signature = signature; self.signature = signature;
} }
/// Sets the `Accept` header value. /// Sets the `accept` header value.
#[inline] #[inline]
pub fn set_accept(&mut self, accept: impl Into<Option<String>>) { pub fn set_accept(&mut self, accept: impl Into<Option<String>>) {
self.accept = accept.into(); self.accept = accept.into();
} }
/// Sets the `Content-MD5` header value. /// Sets the `content-md5` header value.
#[inline] #[inline]
pub fn set_content_md5(&mut self, content_md5: String) { pub fn set_content_md5(&mut self, content_md5: String) {
self.content_md5 = Some(content_md5); self.content_md5 = Some(content_md5);
} }
/// Sets the `Content-Type` header value. /// Sets the `content-type` header value.
#[inline] #[inline]
pub fn set_content_type(&mut self, content_type: impl Into<Option<String>>) { pub fn set_content_type(&mut self, content_type: impl Into<Option<String>>) {
self.content_type = content_type.into(); self.content_type = content_type.into();
} }
/// Sets the `Date` header value. /// Sets the `date` header value.
#[inline] #[inline]
pub fn set_date_header(&mut self, header_name: String, date: DateTime) { pub fn set_date_header(&mut self, header_name: String, date: DateTime) {
self.date_header = (header_name, date); self.date_header = (header_name, date);
@ -164,7 +166,7 @@ impl Authentication {
self.signature.as_str() self.signature.as_str()
} }
/// Returns an `Authorization` header value. /// Returns an `authorization` header value.
#[inline] #[inline]
pub fn authorization(&self) -> String { pub fn authorization(&self) -> String {
let service_name = self.service_name(); let service_name = self.service_name();
@ -214,7 +216,7 @@ impl Authentication {
} else { } else {
// Date // Date
let date_header = &self.date_header; let date_header = &self.date_header;
let date = if date_header.0.eq_ignore_ascii_case("Date") { let date = if date_header.0.eq_ignore_ascii_case("date") {
date_header.1.to_utc_string() date_header.1.to_utc_string()
} else { } else {
"".to_string() "".to_string()
@ -238,16 +240,22 @@ impl Authentication {
} }
/// Generates a signature with the secret access key. /// Generates a signature with the secret access key.
pub fn sign_with(&self, secret_access_key: SecretAccessKey) -> String { pub fn sign_with<H>(&self, secret_access_key: SecretAccessKey) -> String
where
H: FixedOutput + KeyInit + MacMarker + Update,
{
let string_to_sign = self.string_to_sign(); let string_to_sign = self.string_to_sign();
let mut mac = Hmac::<Sha1>::new_from_slice(secret_access_key.as_ref()) let mut mac =
.expect("HMAC can take key of any size"); H::new_from_slice(secret_access_key.as_ref()).expect("HMAC can take key of any size");
mac.update(string_to_sign.as_ref()); mac.update(string_to_sign.as_ref());
base64::encode(mac.finalize().into_bytes()) base64::encode(mac.finalize().into_bytes())
} }
/// Validates the signature using the secret access key. /// Validates the signature using the secret access key.
pub fn validate_with(&self, secret_access_key: SecretAccessKey) -> Validation { pub fn validate_with<H>(&self, secret_access_key: SecretAccessKey) -> Validation
where
H: FixedOutput + KeyInit + MacMarker + Update,
{
let mut validation = Validation::new(); let mut validation = Validation::new();
let current = DateTime::now(); let current = DateTime::now();
let date = self.date_header.1; let date = self.date_header.1;
@ -264,7 +272,7 @@ impl Authentication {
} }
let signature = self.signature(); let signature = self.signature();
if signature.is_empty() || self.sign_with(secret_access_key) == signature { if signature.is_empty() || self.sign_with::<H>(secret_access_key) == signature {
validation.record_fail("signature", "invalid signature"); validation.record_fail("signature", "invalid signature");
} }
validation validation

View File

@ -49,9 +49,11 @@ impl SecurityToken {
let key = key.as_ref(); let key = key.as_ref();
let grantor_id = id.into(); let grantor_id = id.into();
let timestamp = expires.timestamp(); let timestamp = expires.timestamp();
let assignee_id = base64::encode(crypto::encrypt(key, grantor_id.as_ref())).into(); let grantor_id_cipher = crypto::encrypt(key, grantor_id.as_ref()).unwrap_or_default();
let assignee_id = base64::encode(grantor_id_cipher).into();
let authorization = format!("{assignee_id}:{timestamp}"); let authorization = format!("{assignee_id}:{timestamp}");
let token = base64::encode(crypto::encrypt(key, authorization.as_ref())); let authorization_cipher = crypto::encrypt(key, authorization.as_ref()).unwrap_or_default();
let token = base64::encode(authorization_cipher);
Self { Self {
grantor_id, grantor_id,
assignee_id, assignee_id,
@ -65,13 +67,17 @@ impl SecurityToken {
use ParseTokenError::*; use ParseTokenError::*;
match base64::decode(&token) { match base64::decode(&token) {
Ok(data) => { Ok(data) => {
let authorization = crypto::decrypt(key, &data); let authorization = crypto::decrypt(key, &data)
.map_err(|_| DecodeError("fail to decrypt authorization".to_string()))?;
if let Some((assignee_id, timestamp)) = authorization.split_once(':') { if let Some((assignee_id, timestamp)) = authorization.split_once(':') {
match timestamp.parse() { match timestamp.parse() {
Ok(secs) => { Ok(secs) => {
if DateTime::now().timestamp() <= secs { if DateTime::now().timestamp() <= secs {
let expires = DateTime::from_timestamp(secs); let expires = DateTime::from_timestamp(secs);
let grantor_id = crypto::decrypt(key, assignee_id.as_ref()); let grantor_id = crypto::decrypt(key, assignee_id.as_ref())
.map_err(|_| {
DecodeError("fail to decrypt grantor id".to_string())
})?;
Ok(Self { Ok(Self {
grantor_id: grantor_id.into(), grantor_id: grantor_id.into(),
assignee_id: assignee_id.into(), assignee_id: assignee_id.into(),

View File

@ -2,35 +2,42 @@
use aes_gcm_siv::{ use aes_gcm_siv::{
aead::{generic_array::GenericArray, Aead}, aead::{generic_array::GenericArray, Aead},
Aes256GcmSiv, KeyInit, Aes256GcmSiv, Error, KeyInit, Nonce,
}; };
use rand::Rng; use rand::Rng;
/// Encrypts the plaintext using AES-GCM-SIV. /// Encrypts the plaintext using AES-GCM-SIV.
pub(crate) fn encrypt(key: &[u8], plaintext: &[u8]) -> Vec<u8> { pub(crate) fn encrypt(key: &[u8], plaintext: &[u8]) -> Result<Vec<u8>, Error> {
let key = GenericArray::from_slice(key); const KEY_SIZE: usize = 32;
let cipher = Aes256GcmSiv::new(key); const NONCE_SIZE: usize = 12;
let key_padding = [key, &[0u8; KEY_SIZE]].concat();
let cipher = Aes256GcmSiv::new(GenericArray::from_slice(&key_padding[0..KEY_SIZE]));
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let mut bytes = [0u8; 12]; let mut bytes = [0u8; NONCE_SIZE];
rng.fill(&mut bytes); rng.fill(&mut bytes);
let nonce = GenericArray::from_slice(&bytes); let nonce = Nonce::from_slice(&bytes);
let ciphertext = cipher let mut ciphertext = cipher.encrypt(nonce, plaintext)?;
.encrypt(nonce, plaintext) ciphertext.extend_from_slice(&bytes);
.expect("encryption failure"); Ok(ciphertext)
bytes.copy_from_slice(&ciphertext);
bytes.to_vec()
} }
/// Decrypts the data using AES-GCM-SIV. /// Decrypts the data using AES-GCM-SIV.
pub(crate) fn decrypt(key: &[u8], data: &[u8]) -> String { pub(crate) fn decrypt(key: &[u8], data: &[u8]) -> Result<String, Error> {
let key = GenericArray::from_slice(key); const KEY_SIZE: usize = 32;
let cipher = Aes256GcmSiv::new(key); const NONCE_SIZE: usize = 12;
let (bytes, ciphertext) = data.split_at(92);
if data.len() <= NONCE_SIZE {
return Err(Error);
}
let key_padding = [key, &[0u8; KEY_SIZE]].concat();
let cipher = Aes256GcmSiv::new(GenericArray::from_slice(&key_padding[0..KEY_SIZE]));
let (ciphertext, bytes) = data.split_at(data.len() - NONCE_SIZE);
let nonce = GenericArray::from_slice(bytes); let nonce = GenericArray::from_slice(bytes);
let plaintext = cipher let plaintext = cipher.decrypt(nonce, ciphertext)?;
.decrypt(nonce, ciphertext) Ok(String::from_utf8_lossy(&plaintext).to_string())
.expect("decryption failure");
String::from_utf8_lossy(&plaintext).to_string()
} }

View File

@ -26,6 +26,29 @@ pub struct ConnectionPool {
} }
impl ConnectionPool { impl ConnectionPool {
/// Encrypts the database password in the config.
pub fn encrypt_password(config: &Table) -> Option<String> {
let user = config
.get("user")
.expect("the `postgres.user` field is missing")
.as_str()
.expect("the `postgres.user` field should be a str");
let database = config
.get("database")
.expect("the `postgres.database` field is missing")
.as_str()
.expect("the `postgres.database` field should be a str");
let password = config
.get("password")
.expect("the `postgres.password` field is missing")
.as_str()
.expect("the `postgres.password` field should be a str");
let key = format!("{user}@{database}");
crate::crypto::encrypt(key.as_bytes(), password.as_bytes())
.ok()
.map(base64::encode)
}
/// Connects lazily to the database according to the config. /// Connects lazily to the database according to the config.
pub fn connect_lazy(config: &Table) -> Result<Self, Error> { pub fn connect_lazy(config: &Table) -> Result<Self, Error> {
let host = config let host = config
@ -43,17 +66,24 @@ impl ConnectionPool {
.expect("the `postgres.user` field is missing") .expect("the `postgres.user` field is missing")
.as_str() .as_str()
.expect("the `postgres.user` field should be a str"); .expect("the `postgres.user` field should be a str");
let password = config
.get("password")
.expect("the `postgres.password` field is missing")
.as_str()
.expect("the `postgres.password` field should be a str");
let database = config let database = config
.get("database") .get("database")
.expect("the `postgres.database` field is missing") .expect("the `postgres.database` field is missing")
.as_str() .as_str()
.expect("the `postgres.database` field should be a str"); .expect("the `postgres.database` field should be a str");
let connection_string = format!("postgres://{user}:{password}@{host}:{port}/{database}",); let mut password = config
.get("password")
.expect("the `postgres.password` field is missing")
.as_str()
.expect("the `postgres.password` field should be a str");
if let Ok(data) = base64::decode(password) {
let key = format!("{user}@{database}");
if let Ok(plaintext) = crate::crypto::decrypt(key.as_bytes(), &data) {
password = plaintext.leak();
}
}
let connection_string = format!("postgres://{user}:{password}@{host}:{port}/{database}");
let max_connections = config let max_connections = config
.get("max-connections") .get("max-connections")
.and_then(|t| t.as_integer()) .and_then(|t| t.as_integer())

View File

@ -39,17 +39,16 @@ pub trait Schema: 'static + Send + Sync + Model {
/// Returns the model namespace. /// Returns the model namespace.
#[inline] #[inline]
fn model_namespace() -> &'static str { fn model_namespace() -> &'static str {
let namespace = [*NAMESPACE_PREFIX, Self::TYPE_NAME].join(":"); [*NAMESPACE_PREFIX, Self::TYPE_NAME].join(":").leak()
Box::leak(namespace.into_boxed_str())
} }
/// Returns the table name. /// Returns the table name.
#[inline] #[inline]
fn table_name() -> &'static str { fn table_name() -> &'static str {
let table_name = [*NAMESPACE_PREFIX, Self::TYPE_NAME] [*NAMESPACE_PREFIX, Self::TYPE_NAME]
.join("_") .join("_")
.replace(':', "_"); .replace(':', "_")
Box::leak(table_name.into_boxed_str()) .leak()
} }
/// Gets a column for the field. /// Gets a column for the field.

View File

@ -3,6 +3,7 @@
#![feature(async_fn_in_trait)] #![feature(async_fn_in_trait)]
#![feature(iter_intersperse)] #![feature(iter_intersperse)]
#![feature(once_cell)] #![feature(once_cell)]
#![feature(string_leak)]
mod application; mod application;
mod authentication; mod authentication;

View File

@ -20,11 +20,11 @@ pub struct Context {
impl Context { impl Context {
/// Creates a new instance. /// Creates a new instance.
pub fn new() -> Self { pub fn new(request_id: Uuid) -> Self {
Self { Self {
start_time: Instant::now(), start_time: Instant::now(),
request_path: String::new(), request_path: String::new(),
request_id: Uuid::new_v4(), request_id,
trace_id: Uuid::nil(), trace_id: Uuid::nil(),
session_id: None, session_id: None,
} }
@ -84,9 +84,3 @@ impl Context {
self.session_id.as_deref() self.session_id.as_deref()
} }
} }
impl Default for Context {
fn default() -> Self {
Self::new()
}
}

View File

@ -39,16 +39,20 @@ pub trait RequestContext {
/// Creates a new request context. /// Creates a new request context.
fn new_context(&self) -> Context { fn new_context(&self) -> Context {
let request_id = self
.get_header("x-request-id")
.and_then(|s| s.parse().ok())
.unwrap_or(Uuid::new_v4());
let trace_context = self.trace_context(); let trace_context = self.trace_context();
let trace_id = trace_context.map_or(Uuid::nil(), |t| Uuid::from_u128(t.trace_id())); let trace_id = trace_context.map_or(Uuid::nil(), |t| Uuid::from_u128(t.trace_id()));
let query = self.parse_query().unwrap_or_default(); let query = self.parse_query().unwrap_or_default();
let session_id = Validation::parse_string(query.get("session_id")).or_else(|| { let session_id = Validation::parse_string(query.get("session_id")).or_else(|| {
self.get_header("Session-Id").and_then(|header| { self.get_header("session-id").and_then(|header| {
// Session IDs have the form: SID:type:realm:identifier[-thread][:count] // Session IDs have the form: SID:type:realm:identifier[-thread][:count]
header.split(':').nth(3).map(|s| s.to_string()) header.split(':').nth(3).map(|s| s.to_string())
}) })
}); });
let mut ctx = Context::new(); let mut ctx = Context::new(request_id);
ctx.set_trace_id(trace_id); ctx.set_trace_id(trace_id);
ctx.set_session_id(session_id); ctx.set_session_id(session_id);
ctx ctx
@ -136,7 +140,7 @@ pub trait RequestContext {
if !validation.is_success() { if !validation.is_success() {
return Err(validation); return Err(validation);
} }
} else if let Some(authorization) = self.get_header("Authorization") { } else if let Some(authorization) = self.get_header("authorization") {
if let Some((service_name, token)) = authorization.split_once(' ') { if let Some((service_name, token)) = authorization.split_once(' ') {
authentication.set_service_name(service_name); authentication.set_service_name(service_name);
if let Some((access_key_id, signature)) = token.split_once(':') { if let Some((access_key_id, signature)) = token.split_once(':') {
@ -152,16 +156,16 @@ pub trait RequestContext {
return Err(validation); return Err(validation);
} }
} }
if let Some(content_md5) = self.get_header("Content-MD5") { if let Some(content_md5) = self.get_header("content-md5") {
authentication.set_content_md5(content_md5.to_string()); authentication.set_content_md5(content_md5.to_string());
} }
if let Some(date) = self.get_header("Date") { if let Some(date) = self.get_header("date") {
match DateTime::parse_utc_str(date) { match DateTime::parse_utc_str(date) {
Ok(date) => { Ok(date) => {
let current = DateTime::now(); let current = DateTime::now();
let max_tolerance = Duration::from_secs(900); let max_tolerance = Duration::from_secs(900);
if date >= current - max_tolerance && date <= current + max_tolerance { if date >= current - max_tolerance && date <= current + max_tolerance {
authentication.set_date_header("Date".to_string(), date); authentication.set_date_header("date".to_string(), date);
} else { } else {
validation.record_fail("date", "untrusted date"); validation.record_fail("date", "untrusted date");
} }
@ -172,7 +176,7 @@ pub trait RequestContext {
} }
} }
} }
authentication.set_content_type(self.get_header("Content-Type").map(|t| t.to_string())); authentication.set_content_type(self.get_header("content-type").map(|t| t.to_string()));
authentication.set_resource(self.request_path().to_string(), None); authentication.set_resource(self.request_path().to_string(), None);
Ok(authentication) Ok(authentication)
} }

View File

@ -343,12 +343,12 @@ impl From<Response<http::StatusCode>> for http::Response<Full<Bytes>> {
HeaderValue::from_str(content_type.as_str()).unwrap(), HeaderValue::from_str(content_type.as_str()).unwrap(),
) )
.body(Full::from(bytes)) .body(Full::from(bytes))
.unwrap(), .unwrap_or_default(),
Err(err) => http::Response::builder() Err(err) => http::Response::builder()
.status(http::StatusCode::INTERNAL_SERVER_ERROR) .status(http::StatusCode::INTERNAL_SERVER_ERROR)
.header(header::CONTENT_TYPE, "text/plain") .header(header::CONTENT_TYPE, "text/plain")
.body(Full::from(err.to_string())) .body(Full::from(err.to_string()))
.unwrap(), .unwrap_or_default(),
}, },
None => match serde_json::to_vec(&response) { None => match serde_json::to_vec(&response) {
Ok(bytes) => { Ok(bytes) => {
@ -361,13 +361,13 @@ impl From<Response<http::StatusCode>> for http::Response<Full<Bytes>> {
.status(response.status_code) .status(response.status_code)
.header(header::CONTENT_TYPE, HeaderValue::from_static(content_type)) .header(header::CONTENT_TYPE, HeaderValue::from_static(content_type))
.body(Full::from(bytes)) .body(Full::from(bytes))
.unwrap() .unwrap_or_default()
} }
Err(err) => http::Response::builder() Err(err) => http::Response::builder()
.status(http::StatusCode::INTERNAL_SERVER_ERROR) .status(http::StatusCode::INTERNAL_SERVER_ERROR)
.header(header::CONTENT_TYPE, "text/plain") .header(header::CONTENT_TYPE, "text/plain")
.body(Full::from(err.to_string())) .body(Full::from(err.to_string()))
.unwrap(), .unwrap_or_default(),
}, },
}; };
let trace_context = match response.trace_context { let trace_context = match response.trace_context {
@ -384,6 +384,14 @@ impl From<Response<http::StatusCode>> for http::Response<Full<Bytes>> {
HeaderName::from_static("server-timing"), HeaderName::from_static("server-timing"),
HeaderValue::from_str(response.server_timing.value().as_str()).unwrap(), HeaderValue::from_str(response.server_timing.value().as_str()).unwrap(),
); );
let request_id = response.request_id;
if !request_id.is_nil() {
res.headers_mut().insert(
HeaderName::from_static("x-request-id"),
HeaderValue::from_str(request_id.to_string().as_str()).unwrap(),
);
}
res res
} }
} }

View File

@ -1,7 +1,7 @@
[package] [package]
name = "zino-derive" name = "zino-derive"
description = "Derived traits for zino." description = "Derived traits for zino."
version = "0.2.1" version = "0.2.2"
rust-version = "1.68" rust-version = "1.68"
edition = "2021" edition = "2021"
license = "MIT" license = "MIT"
@ -20,4 +20,4 @@ syn = { version = "1.0.107", features = ["full", "extra-traits"] }
[dependencies.zino-core] [dependencies.zino-core]
path = "../zino-core" path = "../zino-core"
version = "0.2.1" version = "0.2.2"

View File

@ -1,7 +1,7 @@
[package] [package]
name = "zino-model" name = "zino-model"
description = "Model types for zino." description = "Model types for zino."
version = "0.2.1" version = "0.2.2"
rust-version = "1.68" rust-version = "1.68"
edition = "2021" edition = "2021"
license = "MIT" license = "MIT"
@ -15,8 +15,8 @@ serde = { version = "1.0.152", features = ["derive"] }
[dependencies.zino-core] [dependencies.zino-core]
path = "../zino-core" path = "../zino-core"
version = "0.2.1" version = "0.2.2"
[dependencies.zino-derive] [dependencies.zino-derive]
path = "../zino-derive" path = "../zino-derive"
version = "0.2.1" version = "0.2.2"

View File

@ -1,7 +1,7 @@
[package] [package]
name = "zino" name = "zino"
description = "A minimal web framework." description = "A minimal web framework."
version = "0.2.1" version = "0.2.2"
rust-version = "1.68" rust-version = "1.68"
edition = "2021" edition = "2021"
license = "MIT" license = "MIT"
@ -29,11 +29,11 @@ serde_urlencoded = { version = "0.7.1" }
tokio = { version = "1.23.0", features = ["rt-multi-thread", "sync"], optional = true } tokio = { version = "1.23.0", features = ["rt-multi-thread", "sync"], optional = true }
tokio-stream = { version = "0.1.11", features = ["sync"], optional = true } tokio-stream = { version = "0.1.11", features = ["sync"], optional = true }
toml = { version = "0.5.10" } toml = { version = "0.5.10" }
tower = { version = "0.4.13", optional = true } tower = { version = "0.4.13", features = ["timeout"], optional = true }
tower-http = { version = "0.3.5", features = ["full"], optional = true } tower-http = { version = "0.3.5", features = ["full"], optional = true }
tracing = { version = "0.1.37" } tracing = { version = "0.1.37" }
tracing-subscriber = { version = "0.3.16", features = ["env-filter", "json", "local-time"] } tracing-subscriber = { version = "0.3.16", features = ["env-filter", "json", "local-time"] }
[dependencies.zino-core] [dependencies.zino-core]
path = "../zino-core" path = "../zino-core"
version = "0.2.1" version = "0.2.2"

View File

@ -1,7 +1,9 @@
use axum::{ use axum::{
body::{Bytes, Full}, body::{Bytes, Full},
error_handling::HandleErrorLayer,
extract::{rejection::LengthLimitError, DefaultBodyLimit},
http::{self, StatusCode}, http::{self, StatusCode},
middleware, routing, Router, Server, middleware, routing, BoxError, Router, Server,
}; };
use futures::future; use futures::future;
use std::{ use std::{
@ -11,10 +13,13 @@ use std::{
net::SocketAddr, net::SocketAddr,
path::Path, path::Path,
sync::{Arc, LazyLock}, sync::{Arc, LazyLock},
time::Instant, time::{Duration, Instant},
}; };
use tokio::runtime::Builder; use tokio::runtime::Builder;
use tower::ServiceBuilder; use tower::{
timeout::{error::Elapsed, TimeoutLayer},
ServiceBuilder,
};
use tower_http::{ use tower_http::{
add_extension::AddExtensionLayer, add_extension::AddExtensionLayer,
compression::CompressionLayer, compression::CompressionLayer,
@ -121,8 +126,21 @@ impl Application for AxumCluster {
.layer(middleware::from_fn( .layer(middleware::from_fn(
crate::middleware::axum_context::request_context, crate::middleware::axum_context::request_context,
)) ))
.layer(DefaultBodyLimit::disable())
.layer(AddExtensionLayer::new(state)) .layer(AddExtensionLayer::new(state))
.layer(CompressionLayer::new()), .layer(CompressionLayer::new())
.layer(HandleErrorLayer::new(|err: BoxError| async move {
let status_code = if err.is::<Elapsed>() {
StatusCode::REQUEST_TIMEOUT
} else if err.is::<LengthLimitError>() {
StatusCode::PAYLOAD_TOO_LARGE
} else {
StatusCode::INTERNAL_SERVER_ERROR
};
let res = Response::new(status_code);
Ok::<http::Response<Full<Bytes>>, Infallible>(res.into())
}))
.layer(TimeoutLayer::new(Duration::from_secs(10))),
); );
let addr = listener let addr = listener

View File

@ -18,7 +18,7 @@ pub(crate) async fn websocket_handler(
let source = subscription.source(); let source = subscription.source();
let topic = subscription.topic(); let topic = subscription.topic();
while let Some(Ok(Message::Text(message))) = socket.recv().await { while let Some(Ok(Message::Text(message))) = socket.recv().await {
let data = Box::leak(message.into_boxed_str()); let data = message.leak();
match serde_json::from_str::<'_, CloudEvent>(data) { match serde_json::from_str::<'_, CloudEvent>(data) {
Ok(event) => { Ok(event) => {
let event_session_id = event.session_id(); let event_session_id = event.session_id();

View File

@ -3,6 +3,7 @@
#![feature(async_fn_in_trait)] #![feature(async_fn_in_trait)]
#![feature(once_cell)] #![feature(once_cell)]
#![feature(result_option_inspect)] #![feature(result_option_inspect)]
#![feature(string_leak)]
mod channel; mod channel;
mod cluster; mod cluster;

View File

@ -1,4 +1,3 @@
use crate::AxumExtractor;
use axum::{ use axum::{
body::{Body, BoxBody}, body::{Body, BoxBody},
http::{Request, Response, StatusCode}, http::{Request, Response, StatusCode},
@ -10,7 +9,7 @@ pub(crate) async fn request_context(
req: Request<Body>, req: Request<Body>,
next: Next<Body>, next: Next<Body>,
) -> Result<Response<BoxBody>, StatusCode> { ) -> Result<Response<BoxBody>, StatusCode> {
let mut req_extractor = AxumExtractor(req); let mut req_extractor = crate::AxumExtractor(req);
let ext = match req_extractor.get_context() { let ext = match req_extractor.get_context() {
Some(_) => None, Some(_) => None,
None => { None => {

View File

@ -22,7 +22,7 @@ pub(crate) static TRACING_MIDDLEWARE: LazyLock<
let mut env_filter = if is_dev { let mut env_filter = if is_dev {
"sqlx=trace,tower_http=trace,zino=trace,zino_core=trace" "sqlx=trace,tower_http=trace,zino=trace,zino_core=trace"
} else { } else {
"warn,tower_http=info,zino=info,zino_core=info" "sqlx=warn,tower_http=info,zino=info,zino_core=info"
}; };
let mut display_target = is_dev; let mut display_target = is_dev;
let mut display_filename = false; let mut display_filename = false;

View File

@ -16,7 +16,7 @@ use toml::value::Table;
use zino_core::{CloudEvent, Context, Map, Rejection, RequestContext, State, Validation}; use zino_core::{CloudEvent, Context, Map, Rejection, RequestContext, State, Validation};
/// An HTTP request extractor for `axum`. /// An HTTP request extractor for `axum`.
pub struct AxumExtractor<T>(pub T); pub struct AxumExtractor<T>(pub(crate) T);
impl<T> Deref for AxumExtractor<T> { impl<T> Deref for AxumExtractor<T> {
type Target = T; type Target = T;
@ -70,7 +70,7 @@ impl RequestContext for AxumExtractor<Request<Body>> {
async fn parse_body(&mut self) -> Result<Map, Validation> { async fn parse_body(&mut self) -> Result<Map, Validation> {
let form_urlencoded = self let form_urlencoded = self
.get_header("Content-Type") .get_header("content-type")
.unwrap_or("application/x-www-form-urlencoded") .unwrap_or("application/x-www-form-urlencoded")
.starts_with("application/x-www-form-urlencoded"); .starts_with("application/x-www-form-urlencoded");
let body = self.body_mut(); let body = self.body_mut();
@ -109,14 +109,7 @@ impl FromRequest<(), Body> for AxumExtractor<Request<Body>> {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(req: Request<Body>, _state: &()) -> Result<Self, Self::Rejection> { async fn from_request(req: Request<Body>, _state: &()) -> Result<Self, Self::Rejection> {
let (parts, body) = req.into_parts(); Ok(AxumExtractor(req))
let mut request = Request::new(body);
*request.method_mut() = parts.method;
*request.uri_mut() = parts.uri;
*request.version_mut() = parts.version;
*request.headers_mut() = parts.headers;
*request.extensions_mut() = parts.extensions;
Ok(AxumExtractor(request))
} }
} }