mirror of https://github.com/zino-rs/zino
Fix crypto
This commit is contained in:
parent
ff6f05aed6
commit
af815b0ea7
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "axum-app"
|
||||
version = "0.2.1"
|
||||
version = "0.2.2"
|
||||
rust-version = "1.68"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
|
||||
name = "data-cube"
|
||||
version = "1.0.2"
|
||||
version = "0.2.2"
|
||||
|
||||
[main]
|
||||
host = "127.0.0.1"
|
||||
|
@ -22,6 +22,9 @@ namespace = "dc"
|
|||
[[postgres]]
|
||||
host = "localhost"
|
||||
port = 5432
|
||||
user = "postgres"
|
||||
password = "postgres"
|
||||
database = "data_cube"
|
||||
user = "postgres"
|
||||
password = "QAx01wnh1i5ER713zfHmZi6dIUYn/Iq9ag+iUGtvKzEFJFYW"
|
||||
|
||||
[tracing]
|
||||
filter = "sqlx=trace,tower_http=trace,zino=trace,zino_core=trace"
|
|
@ -1,6 +1,6 @@
|
|||
|
||||
name = "data-cube"
|
||||
version = "1.0.2"
|
||||
version = "0.2.2"
|
||||
|
||||
[main]
|
||||
host = "127.0.0.1"
|
||||
|
@ -22,9 +22,9 @@ namespace = "dc"
|
|||
[[postgres]]
|
||||
host = "localhost"
|
||||
port = 5432
|
||||
user = "postgres"
|
||||
password = "postgres"
|
||||
database = "data_cube"
|
||||
user = "postgres"
|
||||
password = "G76hTg8T5Aa+SZQFc+0QnsRLo1UOjqpkp/jUQ+lySc8QCt4B"
|
||||
|
||||
[tracing]
|
||||
filter = "warn,zino=info,zino_core=info"
|
||||
filter = "sqlx=warn,tower_http=info,zino=info,zino_core=info"
|
|
@ -1,7 +1,7 @@
|
|||
[package]
|
||||
name = "zino-core"
|
||||
description = "Core types and traits for zino."
|
||||
version = "0.2.1"
|
||||
version = "0.2.2"
|
||||
rust-version = "1.68"
|
||||
edition = "2021"
|
||||
license = "MIT"
|
||||
|
@ -25,7 +25,6 @@ http-types = { version = "2.12.0" }
|
|||
rand = { version = "0.8.5" }
|
||||
serde = { version = "1.0.152", features = ["derive"] }
|
||||
serde_json = { version = "1.0.91" }
|
||||
sha-1 = { version = "0.10.1" }
|
||||
sha2 = { version = "0.10.6" }
|
||||
sqlx = { version = "0.6.2", features = ["runtime-tokio-native-tls", "postgres", "uuid", "time", "json"] }
|
||||
time = { version = "0.3.17", features = ["local-offset", "parsing", "serde"] }
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
use crate::{DateTime, Map, Validation};
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha1::Sha1;
|
||||
use hmac::{
|
||||
digest::{FixedOutput, KeyInit, MacMarker, Update},
|
||||
Mac,
|
||||
};
|
||||
use std::time::Duration;
|
||||
|
||||
mod access_key;
|
||||
|
@ -11,7 +13,7 @@ pub use access_key::{AccessKeyId, SecretAccessKey};
|
|||
pub(crate) use security_token::ParseTokenError;
|
||||
pub use security_token::SecurityToken;
|
||||
|
||||
/// HTTP signature using RFC 2104 HMAC-SHA1.
|
||||
/// HTTP signature using HMAC.
|
||||
pub struct Authentication {
|
||||
/// Service name.
|
||||
service_name: String,
|
||||
|
@ -49,7 +51,7 @@ impl Authentication {
|
|||
accept: None,
|
||||
content_md5: None,
|
||||
content_type: None,
|
||||
date_header: ("Date".to_string(), DateTime::now()),
|
||||
date_header: ("date".to_string(), DateTime::now()),
|
||||
expires: None,
|
||||
headers: Vec::new(),
|
||||
resource: String::new(),
|
||||
|
@ -74,25 +76,25 @@ impl Authentication {
|
|||
self.signature = signature;
|
||||
}
|
||||
|
||||
/// Sets the `Accept` header value.
|
||||
/// Sets the `accept` header value.
|
||||
#[inline]
|
||||
pub fn set_accept(&mut self, accept: impl Into<Option<String>>) {
|
||||
self.accept = accept.into();
|
||||
}
|
||||
|
||||
/// Sets the `Content-MD5` header value.
|
||||
/// Sets the `content-md5` header value.
|
||||
#[inline]
|
||||
pub fn set_content_md5(&mut self, content_md5: String) {
|
||||
self.content_md5 = Some(content_md5);
|
||||
}
|
||||
|
||||
/// Sets the `Content-Type` header value.
|
||||
/// Sets the `content-type` header value.
|
||||
#[inline]
|
||||
pub fn set_content_type(&mut self, content_type: impl Into<Option<String>>) {
|
||||
self.content_type = content_type.into();
|
||||
}
|
||||
|
||||
/// Sets the `Date` header value.
|
||||
/// Sets the `date` header value.
|
||||
#[inline]
|
||||
pub fn set_date_header(&mut self, header_name: String, date: DateTime) {
|
||||
self.date_header = (header_name, date);
|
||||
|
@ -164,7 +166,7 @@ impl Authentication {
|
|||
self.signature.as_str()
|
||||
}
|
||||
|
||||
/// Returns an `Authorization` header value.
|
||||
/// Returns an `authorization` header value.
|
||||
#[inline]
|
||||
pub fn authorization(&self) -> String {
|
||||
let service_name = self.service_name();
|
||||
|
@ -214,7 +216,7 @@ impl Authentication {
|
|||
} else {
|
||||
// Date
|
||||
let date_header = &self.date_header;
|
||||
let date = if date_header.0.eq_ignore_ascii_case("Date") {
|
||||
let date = if date_header.0.eq_ignore_ascii_case("date") {
|
||||
date_header.1.to_utc_string()
|
||||
} else {
|
||||
"".to_string()
|
||||
|
@ -238,16 +240,22 @@ impl Authentication {
|
|||
}
|
||||
|
||||
/// Generates a signature with the secret access key.
|
||||
pub fn sign_with(&self, secret_access_key: SecretAccessKey) -> String {
|
||||
pub fn sign_with<H>(&self, secret_access_key: SecretAccessKey) -> String
|
||||
where
|
||||
H: FixedOutput + KeyInit + MacMarker + Update,
|
||||
{
|
||||
let string_to_sign = self.string_to_sign();
|
||||
let mut mac = Hmac::<Sha1>::new_from_slice(secret_access_key.as_ref())
|
||||
.expect("HMAC can take key of any size");
|
||||
let mut mac =
|
||||
H::new_from_slice(secret_access_key.as_ref()).expect("HMAC can take key of any size");
|
||||
mac.update(string_to_sign.as_ref());
|
||||
base64::encode(mac.finalize().into_bytes())
|
||||
}
|
||||
|
||||
/// Validates the signature using the secret access key.
|
||||
pub fn validate_with(&self, secret_access_key: SecretAccessKey) -> Validation {
|
||||
pub fn validate_with<H>(&self, secret_access_key: SecretAccessKey) -> Validation
|
||||
where
|
||||
H: FixedOutput + KeyInit + MacMarker + Update,
|
||||
{
|
||||
let mut validation = Validation::new();
|
||||
let current = DateTime::now();
|
||||
let date = self.date_header.1;
|
||||
|
@ -264,7 +272,7 @@ impl Authentication {
|
|||
}
|
||||
|
||||
let signature = self.signature();
|
||||
if signature.is_empty() || self.sign_with(secret_access_key) == signature {
|
||||
if signature.is_empty() || self.sign_with::<H>(secret_access_key) == signature {
|
||||
validation.record_fail("signature", "invalid signature");
|
||||
}
|
||||
validation
|
||||
|
|
|
@ -49,9 +49,11 @@ impl SecurityToken {
|
|||
let key = key.as_ref();
|
||||
let grantor_id = id.into();
|
||||
let timestamp = expires.timestamp();
|
||||
let assignee_id = base64::encode(crypto::encrypt(key, grantor_id.as_ref())).into();
|
||||
let grantor_id_cipher = crypto::encrypt(key, grantor_id.as_ref()).unwrap_or_default();
|
||||
let assignee_id = base64::encode(grantor_id_cipher).into();
|
||||
let authorization = format!("{assignee_id}:{timestamp}");
|
||||
let token = base64::encode(crypto::encrypt(key, authorization.as_ref()));
|
||||
let authorization_cipher = crypto::encrypt(key, authorization.as_ref()).unwrap_or_default();
|
||||
let token = base64::encode(authorization_cipher);
|
||||
Self {
|
||||
grantor_id,
|
||||
assignee_id,
|
||||
|
@ -65,13 +67,17 @@ impl SecurityToken {
|
|||
use ParseTokenError::*;
|
||||
match base64::decode(&token) {
|
||||
Ok(data) => {
|
||||
let authorization = crypto::decrypt(key, &data);
|
||||
let authorization = crypto::decrypt(key, &data)
|
||||
.map_err(|_| DecodeError("fail to decrypt authorization".to_string()))?;
|
||||
if let Some((assignee_id, timestamp)) = authorization.split_once(':') {
|
||||
match timestamp.parse() {
|
||||
Ok(secs) => {
|
||||
if DateTime::now().timestamp() <= secs {
|
||||
let expires = DateTime::from_timestamp(secs);
|
||||
let grantor_id = crypto::decrypt(key, assignee_id.as_ref());
|
||||
let grantor_id = crypto::decrypt(key, assignee_id.as_ref())
|
||||
.map_err(|_| {
|
||||
DecodeError("fail to decrypt grantor id".to_string())
|
||||
})?;
|
||||
Ok(Self {
|
||||
grantor_id: grantor_id.into(),
|
||||
assignee_id: assignee_id.into(),
|
||||
|
|
|
@ -2,35 +2,42 @@
|
|||
|
||||
use aes_gcm_siv::{
|
||||
aead::{generic_array::GenericArray, Aead},
|
||||
Aes256GcmSiv, KeyInit,
|
||||
Aes256GcmSiv, Error, KeyInit, Nonce,
|
||||
};
|
||||
use rand::Rng;
|
||||
|
||||
/// Encrypts the plaintext using AES-GCM-SIV.
|
||||
pub(crate) fn encrypt(key: &[u8], plaintext: &[u8]) -> Vec<u8> {
|
||||
let key = GenericArray::from_slice(key);
|
||||
let cipher = Aes256GcmSiv::new(key);
|
||||
pub(crate) fn encrypt(key: &[u8], plaintext: &[u8]) -> Result<Vec<u8>, Error> {
|
||||
const KEY_SIZE: usize = 32;
|
||||
const NONCE_SIZE: usize = 12;
|
||||
|
||||
let key_padding = [key, &[0u8; KEY_SIZE]].concat();
|
||||
let cipher = Aes256GcmSiv::new(GenericArray::from_slice(&key_padding[0..KEY_SIZE]));
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut bytes = [0u8; 12];
|
||||
let mut bytes = [0u8; NONCE_SIZE];
|
||||
rng.fill(&mut bytes);
|
||||
|
||||
let nonce = GenericArray::from_slice(&bytes);
|
||||
let ciphertext = cipher
|
||||
.encrypt(nonce, plaintext)
|
||||
.expect("encryption failure");
|
||||
bytes.copy_from_slice(&ciphertext);
|
||||
bytes.to_vec()
|
||||
let nonce = Nonce::from_slice(&bytes);
|
||||
let mut ciphertext = cipher.encrypt(nonce, plaintext)?;
|
||||
ciphertext.extend_from_slice(&bytes);
|
||||
Ok(ciphertext)
|
||||
}
|
||||
|
||||
/// Decrypts the data using AES-GCM-SIV.
|
||||
pub(crate) fn decrypt(key: &[u8], data: &[u8]) -> String {
|
||||
let key = GenericArray::from_slice(key);
|
||||
let cipher = Aes256GcmSiv::new(key);
|
||||
let (bytes, ciphertext) = data.split_at(92);
|
||||
pub(crate) fn decrypt(key: &[u8], data: &[u8]) -> Result<String, Error> {
|
||||
const KEY_SIZE: usize = 32;
|
||||
const NONCE_SIZE: usize = 12;
|
||||
|
||||
if data.len() <= NONCE_SIZE {
|
||||
return Err(Error);
|
||||
}
|
||||
|
||||
let key_padding = [key, &[0u8; KEY_SIZE]].concat();
|
||||
let cipher = Aes256GcmSiv::new(GenericArray::from_slice(&key_padding[0..KEY_SIZE]));
|
||||
|
||||
let (ciphertext, bytes) = data.split_at(data.len() - NONCE_SIZE);
|
||||
let nonce = GenericArray::from_slice(bytes);
|
||||
let plaintext = cipher
|
||||
.decrypt(nonce, ciphertext)
|
||||
.expect("decryption failure");
|
||||
String::from_utf8_lossy(&plaintext).to_string()
|
||||
let plaintext = cipher.decrypt(nonce, ciphertext)?;
|
||||
Ok(String::from_utf8_lossy(&plaintext).to_string())
|
||||
}
|
||||
|
|
|
@ -26,6 +26,29 @@ pub struct ConnectionPool {
|
|||
}
|
||||
|
||||
impl ConnectionPool {
|
||||
/// Encrypts the database password in the config.
|
||||
pub fn encrypt_password(config: &Table) -> Option<String> {
|
||||
let user = config
|
||||
.get("user")
|
||||
.expect("the `postgres.user` field is missing")
|
||||
.as_str()
|
||||
.expect("the `postgres.user` field should be a str");
|
||||
let database = config
|
||||
.get("database")
|
||||
.expect("the `postgres.database` field is missing")
|
||||
.as_str()
|
||||
.expect("the `postgres.database` field should be a str");
|
||||
let password = config
|
||||
.get("password")
|
||||
.expect("the `postgres.password` field is missing")
|
||||
.as_str()
|
||||
.expect("the `postgres.password` field should be a str");
|
||||
let key = format!("{user}@{database}");
|
||||
crate::crypto::encrypt(key.as_bytes(), password.as_bytes())
|
||||
.ok()
|
||||
.map(base64::encode)
|
||||
}
|
||||
|
||||
/// Connects lazily to the database according to the config.
|
||||
pub fn connect_lazy(config: &Table) -> Result<Self, Error> {
|
||||
let host = config
|
||||
|
@ -43,17 +66,24 @@ impl ConnectionPool {
|
|||
.expect("the `postgres.user` field is missing")
|
||||
.as_str()
|
||||
.expect("the `postgres.user` field should be a str");
|
||||
let password = config
|
||||
.get("password")
|
||||
.expect("the `postgres.password` field is missing")
|
||||
.as_str()
|
||||
.expect("the `postgres.password` field should be a str");
|
||||
let database = config
|
||||
.get("database")
|
||||
.expect("the `postgres.database` field is missing")
|
||||
.as_str()
|
||||
.expect("the `postgres.database` field should be a str");
|
||||
let connection_string = format!("postgres://{user}:{password}@{host}:{port}/{database}",);
|
||||
let mut password = config
|
||||
.get("password")
|
||||
.expect("the `postgres.password` field is missing")
|
||||
.as_str()
|
||||
.expect("the `postgres.password` field should be a str");
|
||||
if let Ok(data) = base64::decode(password) {
|
||||
let key = format!("{user}@{database}");
|
||||
if let Ok(plaintext) = crate::crypto::decrypt(key.as_bytes(), &data) {
|
||||
password = plaintext.leak();
|
||||
}
|
||||
}
|
||||
|
||||
let connection_string = format!("postgres://{user}:{password}@{host}:{port}/{database}");
|
||||
let max_connections = config
|
||||
.get("max-connections")
|
||||
.and_then(|t| t.as_integer())
|
||||
|
|
|
@ -39,17 +39,16 @@ pub trait Schema: 'static + Send + Sync + Model {
|
|||
/// Returns the model namespace.
|
||||
#[inline]
|
||||
fn model_namespace() -> &'static str {
|
||||
let namespace = [*NAMESPACE_PREFIX, Self::TYPE_NAME].join(":");
|
||||
Box::leak(namespace.into_boxed_str())
|
||||
[*NAMESPACE_PREFIX, Self::TYPE_NAME].join(":").leak()
|
||||
}
|
||||
|
||||
/// Returns the table name.
|
||||
#[inline]
|
||||
fn table_name() -> &'static str {
|
||||
let table_name = [*NAMESPACE_PREFIX, Self::TYPE_NAME]
|
||||
[*NAMESPACE_PREFIX, Self::TYPE_NAME]
|
||||
.join("_")
|
||||
.replace(':', "_");
|
||||
Box::leak(table_name.into_boxed_str())
|
||||
.replace(':', "_")
|
||||
.leak()
|
||||
}
|
||||
|
||||
/// Gets a column for the field.
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
#![feature(async_fn_in_trait)]
|
||||
#![feature(iter_intersperse)]
|
||||
#![feature(once_cell)]
|
||||
#![feature(string_leak)]
|
||||
|
||||
mod application;
|
||||
mod authentication;
|
||||
|
|
|
@ -20,11 +20,11 @@ pub struct Context {
|
|||
|
||||
impl Context {
|
||||
/// Creates a new instance.
|
||||
pub fn new() -> Self {
|
||||
pub fn new(request_id: Uuid) -> Self {
|
||||
Self {
|
||||
start_time: Instant::now(),
|
||||
request_path: String::new(),
|
||||
request_id: Uuid::new_v4(),
|
||||
request_id,
|
||||
trace_id: Uuid::nil(),
|
||||
session_id: None,
|
||||
}
|
||||
|
@ -84,9 +84,3 @@ impl Context {
|
|||
self.session_id.as_deref()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Context {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -39,16 +39,20 @@ pub trait RequestContext {
|
|||
|
||||
/// Creates a new request context.
|
||||
fn new_context(&self) -> Context {
|
||||
let request_id = self
|
||||
.get_header("x-request-id")
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(Uuid::new_v4());
|
||||
let trace_context = self.trace_context();
|
||||
let trace_id = trace_context.map_or(Uuid::nil(), |t| Uuid::from_u128(t.trace_id()));
|
||||
let query = self.parse_query().unwrap_or_default();
|
||||
let session_id = Validation::parse_string(query.get("session_id")).or_else(|| {
|
||||
self.get_header("Session-Id").and_then(|header| {
|
||||
self.get_header("session-id").and_then(|header| {
|
||||
// Session IDs have the form: SID:type:realm:identifier[-thread][:count]
|
||||
header.split(':').nth(3).map(|s| s.to_string())
|
||||
})
|
||||
});
|
||||
let mut ctx = Context::new();
|
||||
let mut ctx = Context::new(request_id);
|
||||
ctx.set_trace_id(trace_id);
|
||||
ctx.set_session_id(session_id);
|
||||
ctx
|
||||
|
@ -136,7 +140,7 @@ pub trait RequestContext {
|
|||
if !validation.is_success() {
|
||||
return Err(validation);
|
||||
}
|
||||
} else if let Some(authorization) = self.get_header("Authorization") {
|
||||
} else if let Some(authorization) = self.get_header("authorization") {
|
||||
if let Some((service_name, token)) = authorization.split_once(' ') {
|
||||
authentication.set_service_name(service_name);
|
||||
if let Some((access_key_id, signature)) = token.split_once(':') {
|
||||
|
@ -152,16 +156,16 @@ pub trait RequestContext {
|
|||
return Err(validation);
|
||||
}
|
||||
}
|
||||
if let Some(content_md5) = self.get_header("Content-MD5") {
|
||||
if let Some(content_md5) = self.get_header("content-md5") {
|
||||
authentication.set_content_md5(content_md5.to_string());
|
||||
}
|
||||
if let Some(date) = self.get_header("Date") {
|
||||
if let Some(date) = self.get_header("date") {
|
||||
match DateTime::parse_utc_str(date) {
|
||||
Ok(date) => {
|
||||
let current = DateTime::now();
|
||||
let max_tolerance = Duration::from_secs(900);
|
||||
if date >= current - max_tolerance && date <= current + max_tolerance {
|
||||
authentication.set_date_header("Date".to_string(), date);
|
||||
authentication.set_date_header("date".to_string(), date);
|
||||
} else {
|
||||
validation.record_fail("date", "untrusted date");
|
||||
}
|
||||
|
@ -172,7 +176,7 @@ pub trait RequestContext {
|
|||
}
|
||||
}
|
||||
}
|
||||
authentication.set_content_type(self.get_header("Content-Type").map(|t| t.to_string()));
|
||||
authentication.set_content_type(self.get_header("content-type").map(|t| t.to_string()));
|
||||
authentication.set_resource(self.request_path().to_string(), None);
|
||||
Ok(authentication)
|
||||
}
|
||||
|
|
|
@ -343,12 +343,12 @@ impl From<Response<http::StatusCode>> for http::Response<Full<Bytes>> {
|
|||
HeaderValue::from_str(content_type.as_str()).unwrap(),
|
||||
)
|
||||
.body(Full::from(bytes))
|
||||
.unwrap(),
|
||||
.unwrap_or_default(),
|
||||
Err(err) => http::Response::builder()
|
||||
.status(http::StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.header(header::CONTENT_TYPE, "text/plain")
|
||||
.body(Full::from(err.to_string()))
|
||||
.unwrap(),
|
||||
.unwrap_or_default(),
|
||||
},
|
||||
None => match serde_json::to_vec(&response) {
|
||||
Ok(bytes) => {
|
||||
|
@ -361,13 +361,13 @@ impl From<Response<http::StatusCode>> for http::Response<Full<Bytes>> {
|
|||
.status(response.status_code)
|
||||
.header(header::CONTENT_TYPE, HeaderValue::from_static(content_type))
|
||||
.body(Full::from(bytes))
|
||||
.unwrap()
|
||||
.unwrap_or_default()
|
||||
}
|
||||
Err(err) => http::Response::builder()
|
||||
.status(http::StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.header(header::CONTENT_TYPE, "text/plain")
|
||||
.body(Full::from(err.to_string()))
|
||||
.unwrap(),
|
||||
.unwrap_or_default(),
|
||||
},
|
||||
};
|
||||
let trace_context = match response.trace_context {
|
||||
|
@ -384,6 +384,14 @@ impl From<Response<http::StatusCode>> for http::Response<Full<Bytes>> {
|
|||
HeaderName::from_static("server-timing"),
|
||||
HeaderValue::from_str(response.server_timing.value().as_str()).unwrap(),
|
||||
);
|
||||
|
||||
let request_id = response.request_id;
|
||||
if !request_id.is_nil() {
|
||||
res.headers_mut().insert(
|
||||
HeaderName::from_static("x-request-id"),
|
||||
HeaderValue::from_str(request_id.to_string().as_str()).unwrap(),
|
||||
);
|
||||
}
|
||||
res
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
[package]
|
||||
name = "zino-derive"
|
||||
description = "Derived traits for zino."
|
||||
version = "0.2.1"
|
||||
version = "0.2.2"
|
||||
rust-version = "1.68"
|
||||
edition = "2021"
|
||||
license = "MIT"
|
||||
|
@ -20,4 +20,4 @@ syn = { version = "1.0.107", features = ["full", "extra-traits"] }
|
|||
|
||||
[dependencies.zino-core]
|
||||
path = "../zino-core"
|
||||
version = "0.2.1"
|
||||
version = "0.2.2"
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
[package]
|
||||
name = "zino-model"
|
||||
description = "Model types for zino."
|
||||
version = "0.2.1"
|
||||
version = "0.2.2"
|
||||
rust-version = "1.68"
|
||||
edition = "2021"
|
||||
license = "MIT"
|
||||
|
@ -15,8 +15,8 @@ serde = { version = "1.0.152", features = ["derive"] }
|
|||
|
||||
[dependencies.zino-core]
|
||||
path = "../zino-core"
|
||||
version = "0.2.1"
|
||||
version = "0.2.2"
|
||||
|
||||
[dependencies.zino-derive]
|
||||
path = "../zino-derive"
|
||||
version = "0.2.1"
|
||||
version = "0.2.2"
|
|
@ -1,7 +1,7 @@
|
|||
[package]
|
||||
name = "zino"
|
||||
description = "A minimal web framework."
|
||||
version = "0.2.1"
|
||||
version = "0.2.2"
|
||||
rust-version = "1.68"
|
||||
edition = "2021"
|
||||
license = "MIT"
|
||||
|
@ -29,11 +29,11 @@ serde_urlencoded = { version = "0.7.1" }
|
|||
tokio = { version = "1.23.0", features = ["rt-multi-thread", "sync"], optional = true }
|
||||
tokio-stream = { version = "0.1.11", features = ["sync"], optional = true }
|
||||
toml = { version = "0.5.10" }
|
||||
tower = { version = "0.4.13", optional = true }
|
||||
tower = { version = "0.4.13", features = ["timeout"], optional = true }
|
||||
tower-http = { version = "0.3.5", features = ["full"], optional = true }
|
||||
tracing = { version = "0.1.37" }
|
||||
tracing-subscriber = { version = "0.3.16", features = ["env-filter", "json", "local-time"] }
|
||||
|
||||
[dependencies.zino-core]
|
||||
path = "../zino-core"
|
||||
version = "0.2.1"
|
||||
version = "0.2.2"
|
|
@ -1,7 +1,9 @@
|
|||
use axum::{
|
||||
body::{Bytes, Full},
|
||||
error_handling::HandleErrorLayer,
|
||||
extract::{rejection::LengthLimitError, DefaultBodyLimit},
|
||||
http::{self, StatusCode},
|
||||
middleware, routing, Router, Server,
|
||||
middleware, routing, BoxError, Router, Server,
|
||||
};
|
||||
use futures::future;
|
||||
use std::{
|
||||
|
@ -11,10 +13,13 @@ use std::{
|
|||
net::SocketAddr,
|
||||
path::Path,
|
||||
sync::{Arc, LazyLock},
|
||||
time::Instant,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokio::runtime::Builder;
|
||||
use tower::ServiceBuilder;
|
||||
use tower::{
|
||||
timeout::{error::Elapsed, TimeoutLayer},
|
||||
ServiceBuilder,
|
||||
};
|
||||
use tower_http::{
|
||||
add_extension::AddExtensionLayer,
|
||||
compression::CompressionLayer,
|
||||
|
@ -121,8 +126,21 @@ impl Application for AxumCluster {
|
|||
.layer(middleware::from_fn(
|
||||
crate::middleware::axum_context::request_context,
|
||||
))
|
||||
.layer(DefaultBodyLimit::disable())
|
||||
.layer(AddExtensionLayer::new(state))
|
||||
.layer(CompressionLayer::new()),
|
||||
.layer(CompressionLayer::new())
|
||||
.layer(HandleErrorLayer::new(|err: BoxError| async move {
|
||||
let status_code = if err.is::<Elapsed>() {
|
||||
StatusCode::REQUEST_TIMEOUT
|
||||
} else if err.is::<LengthLimitError>() {
|
||||
StatusCode::PAYLOAD_TOO_LARGE
|
||||
} else {
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
};
|
||||
let res = Response::new(status_code);
|
||||
Ok::<http::Response<Full<Bytes>>, Infallible>(res.into())
|
||||
}))
|
||||
.layer(TimeoutLayer::new(Duration::from_secs(10))),
|
||||
);
|
||||
|
||||
let addr = listener
|
||||
|
|
|
@ -18,7 +18,7 @@ pub(crate) async fn websocket_handler(
|
|||
let source = subscription.source();
|
||||
let topic = subscription.topic();
|
||||
while let Some(Ok(Message::Text(message))) = socket.recv().await {
|
||||
let data = Box::leak(message.into_boxed_str());
|
||||
let data = message.leak();
|
||||
match serde_json::from_str::<'_, CloudEvent>(data) {
|
||||
Ok(event) => {
|
||||
let event_session_id = event.session_id();
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
#![feature(async_fn_in_trait)]
|
||||
#![feature(once_cell)]
|
||||
#![feature(result_option_inspect)]
|
||||
#![feature(string_leak)]
|
||||
|
||||
mod channel;
|
||||
mod cluster;
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
use crate::AxumExtractor;
|
||||
use axum::{
|
||||
body::{Body, BoxBody},
|
||||
http::{Request, Response, StatusCode},
|
||||
|
@ -10,7 +9,7 @@ pub(crate) async fn request_context(
|
|||
req: Request<Body>,
|
||||
next: Next<Body>,
|
||||
) -> Result<Response<BoxBody>, StatusCode> {
|
||||
let mut req_extractor = AxumExtractor(req);
|
||||
let mut req_extractor = crate::AxumExtractor(req);
|
||||
let ext = match req_extractor.get_context() {
|
||||
Some(_) => None,
|
||||
None => {
|
||||
|
|
|
@ -22,7 +22,7 @@ pub(crate) static TRACING_MIDDLEWARE: LazyLock<
|
|||
let mut env_filter = if is_dev {
|
||||
"sqlx=trace,tower_http=trace,zino=trace,zino_core=trace"
|
||||
} else {
|
||||
"warn,tower_http=info,zino=info,zino_core=info"
|
||||
"sqlx=warn,tower_http=info,zino=info,zino_core=info"
|
||||
};
|
||||
let mut display_target = is_dev;
|
||||
let mut display_filename = false;
|
||||
|
|
|
@ -16,7 +16,7 @@ use toml::value::Table;
|
|||
use zino_core::{CloudEvent, Context, Map, Rejection, RequestContext, State, Validation};
|
||||
|
||||
/// An HTTP request extractor for `axum`.
|
||||
pub struct AxumExtractor<T>(pub T);
|
||||
pub struct AxumExtractor<T>(pub(crate) T);
|
||||
|
||||
impl<T> Deref for AxumExtractor<T> {
|
||||
type Target = T;
|
||||
|
@ -70,7 +70,7 @@ impl RequestContext for AxumExtractor<Request<Body>> {
|
|||
|
||||
async fn parse_body(&mut self) -> Result<Map, Validation> {
|
||||
let form_urlencoded = self
|
||||
.get_header("Content-Type")
|
||||
.get_header("content-type")
|
||||
.unwrap_or("application/x-www-form-urlencoded")
|
||||
.starts_with("application/x-www-form-urlencoded");
|
||||
let body = self.body_mut();
|
||||
|
@ -109,14 +109,7 @@ impl FromRequest<(), Body> for AxumExtractor<Request<Body>> {
|
|||
type Rejection = Infallible;
|
||||
|
||||
async fn from_request(req: Request<Body>, _state: &()) -> Result<Self, Self::Rejection> {
|
||||
let (parts, body) = req.into_parts();
|
||||
let mut request = Request::new(body);
|
||||
*request.method_mut() = parts.method;
|
||||
*request.uri_mut() = parts.uri;
|
||||
*request.version_mut() = parts.version;
|
||||
*request.headers_mut() = parts.headers;
|
||||
*request.extensions_mut() = parts.extensions;
|
||||
Ok(AxumExtractor(request))
|
||||
Ok(AxumExtractor(req))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue