Fix crypto

This commit is contained in:
photino 2023-01-06 04:55:37 +08:00
parent ff6f05aed6
commit af815b0ea7
22 changed files with 174 additions and 104 deletions

View File

@ -1,6 +1,6 @@
[package]
name = "axum-app"
version = "0.2.1"
version = "0.2.2"
rust-version = "1.68"
edition = "2021"
publish = false

View File

@ -1,6 +1,6 @@
name = "data-cube"
version = "1.0.2"
version = "0.2.2"
[main]
host = "127.0.0.1"
@ -22,6 +22,9 @@ namespace = "dc"
[[postgres]]
host = "localhost"
port = 5432
user = "postgres"
password = "postgres"
database = "data_cube"
user = "postgres"
password = "QAx01wnh1i5ER713zfHmZi6dIUYn/Iq9ag+iUGtvKzEFJFYW"
[tracing]
filter = "sqlx=trace,tower_http=trace,zino=trace,zino_core=trace"

View File

@ -1,6 +1,6 @@
name = "data-cube"
version = "1.0.2"
version = "0.2.2"
[main]
host = "127.0.0.1"
@ -22,9 +22,9 @@ namespace = "dc"
[[postgres]]
host = "localhost"
port = 5432
user = "postgres"
password = "postgres"
database = "data_cube"
user = "postgres"
password = "G76hTg8T5Aa+SZQFc+0QnsRLo1UOjqpkp/jUQ+lySc8QCt4B"
[tracing]
filter = "warn,zino=info,zino_core=info"
filter = "sqlx=warn,tower_http=info,zino=info,zino_core=info"

View File

@ -1,7 +1,7 @@
[package]
name = "zino-core"
description = "Core types and traits for zino."
version = "0.2.1"
version = "0.2.2"
rust-version = "1.68"
edition = "2021"
license = "MIT"
@ -25,7 +25,6 @@ http-types = { version = "2.12.0" }
rand = { version = "0.8.5" }
serde = { version = "1.0.152", features = ["derive"] }
serde_json = { version = "1.0.91" }
sha-1 = { version = "0.10.1" }
sha2 = { version = "0.10.6" }
sqlx = { version = "0.6.2", features = ["runtime-tokio-native-tls", "postgres", "uuid", "time", "json"] }
time = { version = "0.3.17", features = ["local-offset", "parsing", "serde"] }

View File

@ -1,6 +1,8 @@
use crate::{DateTime, Map, Validation};
use hmac::{Hmac, Mac};
use sha1::Sha1;
use hmac::{
digest::{FixedOutput, KeyInit, MacMarker, Update},
Mac,
};
use std::time::Duration;
mod access_key;
@ -11,7 +13,7 @@ pub use access_key::{AccessKeyId, SecretAccessKey};
pub(crate) use security_token::ParseTokenError;
pub use security_token::SecurityToken;
/// HTTP signature using RFC 2104 HMAC-SHA1.
/// HTTP signature using HMAC.
pub struct Authentication {
/// Service name.
service_name: String,
@ -49,7 +51,7 @@ impl Authentication {
accept: None,
content_md5: None,
content_type: None,
date_header: ("Date".to_string(), DateTime::now()),
date_header: ("date".to_string(), DateTime::now()),
expires: None,
headers: Vec::new(),
resource: String::new(),
@ -74,25 +76,25 @@ impl Authentication {
self.signature = signature;
}
/// Sets the `Accept` header value.
/// Sets the `accept` header value.
#[inline]
pub fn set_accept(&mut self, accept: impl Into<Option<String>>) {
self.accept = accept.into();
}
/// Sets the `Content-MD5` header value.
/// Sets the `content-md5` header value.
#[inline]
pub fn set_content_md5(&mut self, content_md5: String) {
self.content_md5 = Some(content_md5);
}
/// Sets the `Content-Type` header value.
/// Sets the `content-type` header value.
#[inline]
pub fn set_content_type(&mut self, content_type: impl Into<Option<String>>) {
self.content_type = content_type.into();
}
/// Sets the `Date` header value.
/// Sets the `date` header value.
#[inline]
pub fn set_date_header(&mut self, header_name: String, date: DateTime) {
self.date_header = (header_name, date);
@ -164,7 +166,7 @@ impl Authentication {
self.signature.as_str()
}
/// Returns an `Authorization` header value.
/// Returns an `authorization` header value.
#[inline]
pub fn authorization(&self) -> String {
let service_name = self.service_name();
@ -214,7 +216,7 @@ impl Authentication {
} else {
// Date
let date_header = &self.date_header;
let date = if date_header.0.eq_ignore_ascii_case("Date") {
let date = if date_header.0.eq_ignore_ascii_case("date") {
date_header.1.to_utc_string()
} else {
"".to_string()
@ -238,16 +240,22 @@ impl Authentication {
}
/// Generates a signature with the secret access key.
pub fn sign_with(&self, secret_access_key: SecretAccessKey) -> String {
pub fn sign_with<H>(&self, secret_access_key: SecretAccessKey) -> String
where
H: FixedOutput + KeyInit + MacMarker + Update,
{
let string_to_sign = self.string_to_sign();
let mut mac = Hmac::<Sha1>::new_from_slice(secret_access_key.as_ref())
.expect("HMAC can take key of any size");
let mut mac =
H::new_from_slice(secret_access_key.as_ref()).expect("HMAC can take key of any size");
mac.update(string_to_sign.as_ref());
base64::encode(mac.finalize().into_bytes())
}
/// Validates the signature using the secret access key.
pub fn validate_with(&self, secret_access_key: SecretAccessKey) -> Validation {
pub fn validate_with<H>(&self, secret_access_key: SecretAccessKey) -> Validation
where
H: FixedOutput + KeyInit + MacMarker + Update,
{
let mut validation = Validation::new();
let current = DateTime::now();
let date = self.date_header.1;
@ -264,7 +272,7 @@ impl Authentication {
}
let signature = self.signature();
if signature.is_empty() || self.sign_with(secret_access_key) == signature {
if signature.is_empty() || self.sign_with::<H>(secret_access_key) == signature {
validation.record_fail("signature", "invalid signature");
}
validation

View File

@ -49,9 +49,11 @@ impl SecurityToken {
let key = key.as_ref();
let grantor_id = id.into();
let timestamp = expires.timestamp();
let assignee_id = base64::encode(crypto::encrypt(key, grantor_id.as_ref())).into();
let grantor_id_cipher = crypto::encrypt(key, grantor_id.as_ref()).unwrap_or_default();
let assignee_id = base64::encode(grantor_id_cipher).into();
let authorization = format!("{assignee_id}:{timestamp}");
let token = base64::encode(crypto::encrypt(key, authorization.as_ref()));
let authorization_cipher = crypto::encrypt(key, authorization.as_ref()).unwrap_or_default();
let token = base64::encode(authorization_cipher);
Self {
grantor_id,
assignee_id,
@ -65,13 +67,17 @@ impl SecurityToken {
use ParseTokenError::*;
match base64::decode(&token) {
Ok(data) => {
let authorization = crypto::decrypt(key, &data);
let authorization = crypto::decrypt(key, &data)
.map_err(|_| DecodeError("fail to decrypt authorization".to_string()))?;
if let Some((assignee_id, timestamp)) = authorization.split_once(':') {
match timestamp.parse() {
Ok(secs) => {
if DateTime::now().timestamp() <= secs {
let expires = DateTime::from_timestamp(secs);
let grantor_id = crypto::decrypt(key, assignee_id.as_ref());
let grantor_id = crypto::decrypt(key, assignee_id.as_ref())
.map_err(|_| {
DecodeError("fail to decrypt grantor id".to_string())
})?;
Ok(Self {
grantor_id: grantor_id.into(),
assignee_id: assignee_id.into(),

View File

@ -2,35 +2,42 @@
use aes_gcm_siv::{
aead::{generic_array::GenericArray, Aead},
Aes256GcmSiv, KeyInit,
Aes256GcmSiv, Error, KeyInit, Nonce,
};
use rand::Rng;
/// Encrypts the plaintext using AES-GCM-SIV.
pub(crate) fn encrypt(key: &[u8], plaintext: &[u8]) -> Vec<u8> {
let key = GenericArray::from_slice(key);
let cipher = Aes256GcmSiv::new(key);
pub(crate) fn encrypt(key: &[u8], plaintext: &[u8]) -> Result<Vec<u8>, Error> {
const KEY_SIZE: usize = 32;
const NONCE_SIZE: usize = 12;
let key_padding = [key, &[0u8; KEY_SIZE]].concat();
let cipher = Aes256GcmSiv::new(GenericArray::from_slice(&key_padding[0..KEY_SIZE]));
let mut rng = rand::thread_rng();
let mut bytes = [0u8; 12];
let mut bytes = [0u8; NONCE_SIZE];
rng.fill(&mut bytes);
let nonce = GenericArray::from_slice(&bytes);
let ciphertext = cipher
.encrypt(nonce, plaintext)
.expect("encryption failure");
bytes.copy_from_slice(&ciphertext);
bytes.to_vec()
let nonce = Nonce::from_slice(&bytes);
let mut ciphertext = cipher.encrypt(nonce, plaintext)?;
ciphertext.extend_from_slice(&bytes);
Ok(ciphertext)
}
/// Decrypts the data using AES-GCM-SIV.
pub(crate) fn decrypt(key: &[u8], data: &[u8]) -> String {
let key = GenericArray::from_slice(key);
let cipher = Aes256GcmSiv::new(key);
let (bytes, ciphertext) = data.split_at(92);
pub(crate) fn decrypt(key: &[u8], data: &[u8]) -> Result<String, Error> {
const KEY_SIZE: usize = 32;
const NONCE_SIZE: usize = 12;
if data.len() <= NONCE_SIZE {
return Err(Error);
}
let key_padding = [key, &[0u8; KEY_SIZE]].concat();
let cipher = Aes256GcmSiv::new(GenericArray::from_slice(&key_padding[0..KEY_SIZE]));
let (ciphertext, bytes) = data.split_at(data.len() - NONCE_SIZE);
let nonce = GenericArray::from_slice(bytes);
let plaintext = cipher
.decrypt(nonce, ciphertext)
.expect("decryption failure");
String::from_utf8_lossy(&plaintext).to_string()
let plaintext = cipher.decrypt(nonce, ciphertext)?;
Ok(String::from_utf8_lossy(&plaintext).to_string())
}

View File

@ -26,6 +26,29 @@ pub struct ConnectionPool {
}
impl ConnectionPool {
/// Encrypts the database password in the config.
pub fn encrypt_password(config: &Table) -> Option<String> {
let user = config
.get("user")
.expect("the `postgres.user` field is missing")
.as_str()
.expect("the `postgres.user` field should be a str");
let database = config
.get("database")
.expect("the `postgres.database` field is missing")
.as_str()
.expect("the `postgres.database` field should be a str");
let password = config
.get("password")
.expect("the `postgres.password` field is missing")
.as_str()
.expect("the `postgres.password` field should be a str");
let key = format!("{user}@{database}");
crate::crypto::encrypt(key.as_bytes(), password.as_bytes())
.ok()
.map(base64::encode)
}
/// Connects lazily to the database according to the config.
pub fn connect_lazy(config: &Table) -> Result<Self, Error> {
let host = config
@ -43,17 +66,24 @@ impl ConnectionPool {
.expect("the `postgres.user` field is missing")
.as_str()
.expect("the `postgres.user` field should be a str");
let password = config
.get("password")
.expect("the `postgres.password` field is missing")
.as_str()
.expect("the `postgres.password` field should be a str");
let database = config
.get("database")
.expect("the `postgres.database` field is missing")
.as_str()
.expect("the `postgres.database` field should be a str");
let connection_string = format!("postgres://{user}:{password}@{host}:{port}/{database}",);
let mut password = config
.get("password")
.expect("the `postgres.password` field is missing")
.as_str()
.expect("the `postgres.password` field should be a str");
if let Ok(data) = base64::decode(password) {
let key = format!("{user}@{database}");
if let Ok(plaintext) = crate::crypto::decrypt(key.as_bytes(), &data) {
password = plaintext.leak();
}
}
let connection_string = format!("postgres://{user}:{password}@{host}:{port}/{database}");
let max_connections = config
.get("max-connections")
.and_then(|t| t.as_integer())

View File

@ -39,17 +39,16 @@ pub trait Schema: 'static + Send + Sync + Model {
/// Returns the model namespace.
#[inline]
fn model_namespace() -> &'static str {
let namespace = [*NAMESPACE_PREFIX, Self::TYPE_NAME].join(":");
Box::leak(namespace.into_boxed_str())
[*NAMESPACE_PREFIX, Self::TYPE_NAME].join(":").leak()
}
/// Returns the table name.
#[inline]
fn table_name() -> &'static str {
let table_name = [*NAMESPACE_PREFIX, Self::TYPE_NAME]
[*NAMESPACE_PREFIX, Self::TYPE_NAME]
.join("_")
.replace(':', "_");
Box::leak(table_name.into_boxed_str())
.replace(':', "_")
.leak()
}
/// Gets a column for the field.

View File

@ -3,6 +3,7 @@
#![feature(async_fn_in_trait)]
#![feature(iter_intersperse)]
#![feature(once_cell)]
#![feature(string_leak)]
mod application;
mod authentication;

View File

@ -20,11 +20,11 @@ pub struct Context {
impl Context {
/// Creates a new instance.
pub fn new() -> Self {
pub fn new(request_id: Uuid) -> Self {
Self {
start_time: Instant::now(),
request_path: String::new(),
request_id: Uuid::new_v4(),
request_id,
trace_id: Uuid::nil(),
session_id: None,
}
@ -84,9 +84,3 @@ impl Context {
self.session_id.as_deref()
}
}
impl Default for Context {
fn default() -> Self {
Self::new()
}
}

View File

@ -39,16 +39,20 @@ pub trait RequestContext {
/// Creates a new request context.
fn new_context(&self) -> Context {
let request_id = self
.get_header("x-request-id")
.and_then(|s| s.parse().ok())
.unwrap_or(Uuid::new_v4());
let trace_context = self.trace_context();
let trace_id = trace_context.map_or(Uuid::nil(), |t| Uuid::from_u128(t.trace_id()));
let query = self.parse_query().unwrap_or_default();
let session_id = Validation::parse_string(query.get("session_id")).or_else(|| {
self.get_header("Session-Id").and_then(|header| {
self.get_header("session-id").and_then(|header| {
// Session IDs have the form: SID:type:realm:identifier[-thread][:count]
header.split(':').nth(3).map(|s| s.to_string())
})
});
let mut ctx = Context::new();
let mut ctx = Context::new(request_id);
ctx.set_trace_id(trace_id);
ctx.set_session_id(session_id);
ctx
@ -136,7 +140,7 @@ pub trait RequestContext {
if !validation.is_success() {
return Err(validation);
}
} else if let Some(authorization) = self.get_header("Authorization") {
} else if let Some(authorization) = self.get_header("authorization") {
if let Some((service_name, token)) = authorization.split_once(' ') {
authentication.set_service_name(service_name);
if let Some((access_key_id, signature)) = token.split_once(':') {
@ -152,16 +156,16 @@ pub trait RequestContext {
return Err(validation);
}
}
if let Some(content_md5) = self.get_header("Content-MD5") {
if let Some(content_md5) = self.get_header("content-md5") {
authentication.set_content_md5(content_md5.to_string());
}
if let Some(date) = self.get_header("Date") {
if let Some(date) = self.get_header("date") {
match DateTime::parse_utc_str(date) {
Ok(date) => {
let current = DateTime::now();
let max_tolerance = Duration::from_secs(900);
if date >= current - max_tolerance && date <= current + max_tolerance {
authentication.set_date_header("Date".to_string(), date);
authentication.set_date_header("date".to_string(), date);
} else {
validation.record_fail("date", "untrusted date");
}
@ -172,7 +176,7 @@ pub trait RequestContext {
}
}
}
authentication.set_content_type(self.get_header("Content-Type").map(|t| t.to_string()));
authentication.set_content_type(self.get_header("content-type").map(|t| t.to_string()));
authentication.set_resource(self.request_path().to_string(), None);
Ok(authentication)
}

View File

@ -343,12 +343,12 @@ impl From<Response<http::StatusCode>> for http::Response<Full<Bytes>> {
HeaderValue::from_str(content_type.as_str()).unwrap(),
)
.body(Full::from(bytes))
.unwrap(),
.unwrap_or_default(),
Err(err) => http::Response::builder()
.status(http::StatusCode::INTERNAL_SERVER_ERROR)
.header(header::CONTENT_TYPE, "text/plain")
.body(Full::from(err.to_string()))
.unwrap(),
.unwrap_or_default(),
},
None => match serde_json::to_vec(&response) {
Ok(bytes) => {
@ -361,13 +361,13 @@ impl From<Response<http::StatusCode>> for http::Response<Full<Bytes>> {
.status(response.status_code)
.header(header::CONTENT_TYPE, HeaderValue::from_static(content_type))
.body(Full::from(bytes))
.unwrap()
.unwrap_or_default()
}
Err(err) => http::Response::builder()
.status(http::StatusCode::INTERNAL_SERVER_ERROR)
.header(header::CONTENT_TYPE, "text/plain")
.body(Full::from(err.to_string()))
.unwrap(),
.unwrap_or_default(),
},
};
let trace_context = match response.trace_context {
@ -384,6 +384,14 @@ impl From<Response<http::StatusCode>> for http::Response<Full<Bytes>> {
HeaderName::from_static("server-timing"),
HeaderValue::from_str(response.server_timing.value().as_str()).unwrap(),
);
let request_id = response.request_id;
if !request_id.is_nil() {
res.headers_mut().insert(
HeaderName::from_static("x-request-id"),
HeaderValue::from_str(request_id.to_string().as_str()).unwrap(),
);
}
res
}
}

View File

@ -1,7 +1,7 @@
[package]
name = "zino-derive"
description = "Derived traits for zino."
version = "0.2.1"
version = "0.2.2"
rust-version = "1.68"
edition = "2021"
license = "MIT"
@ -20,4 +20,4 @@ syn = { version = "1.0.107", features = ["full", "extra-traits"] }
[dependencies.zino-core]
path = "../zino-core"
version = "0.2.1"
version = "0.2.2"

View File

@ -1,7 +1,7 @@
[package]
name = "zino-model"
description = "Model types for zino."
version = "0.2.1"
version = "0.2.2"
rust-version = "1.68"
edition = "2021"
license = "MIT"
@ -15,8 +15,8 @@ serde = { version = "1.0.152", features = ["derive"] }
[dependencies.zino-core]
path = "../zino-core"
version = "0.2.1"
version = "0.2.2"
[dependencies.zino-derive]
path = "../zino-derive"
version = "0.2.1"
version = "0.2.2"

View File

@ -1,7 +1,7 @@
[package]
name = "zino"
description = "A minimal web framework."
version = "0.2.1"
version = "0.2.2"
rust-version = "1.68"
edition = "2021"
license = "MIT"
@ -29,11 +29,11 @@ serde_urlencoded = { version = "0.7.1" }
tokio = { version = "1.23.0", features = ["rt-multi-thread", "sync"], optional = true }
tokio-stream = { version = "0.1.11", features = ["sync"], optional = true }
toml = { version = "0.5.10" }
tower = { version = "0.4.13", optional = true }
tower = { version = "0.4.13", features = ["timeout"], optional = true }
tower-http = { version = "0.3.5", features = ["full"], optional = true }
tracing = { version = "0.1.37" }
tracing-subscriber = { version = "0.3.16", features = ["env-filter", "json", "local-time"] }
[dependencies.zino-core]
path = "../zino-core"
version = "0.2.1"
version = "0.2.2"

View File

@ -1,7 +1,9 @@
use axum::{
body::{Bytes, Full},
error_handling::HandleErrorLayer,
extract::{rejection::LengthLimitError, DefaultBodyLimit},
http::{self, StatusCode},
middleware, routing, Router, Server,
middleware, routing, BoxError, Router, Server,
};
use futures::future;
use std::{
@ -11,10 +13,13 @@ use std::{
net::SocketAddr,
path::Path,
sync::{Arc, LazyLock},
time::Instant,
time::{Duration, Instant},
};
use tokio::runtime::Builder;
use tower::ServiceBuilder;
use tower::{
timeout::{error::Elapsed, TimeoutLayer},
ServiceBuilder,
};
use tower_http::{
add_extension::AddExtensionLayer,
compression::CompressionLayer,
@ -121,8 +126,21 @@ impl Application for AxumCluster {
.layer(middleware::from_fn(
crate::middleware::axum_context::request_context,
))
.layer(DefaultBodyLimit::disable())
.layer(AddExtensionLayer::new(state))
.layer(CompressionLayer::new()),
.layer(CompressionLayer::new())
.layer(HandleErrorLayer::new(|err: BoxError| async move {
let status_code = if err.is::<Elapsed>() {
StatusCode::REQUEST_TIMEOUT
} else if err.is::<LengthLimitError>() {
StatusCode::PAYLOAD_TOO_LARGE
} else {
StatusCode::INTERNAL_SERVER_ERROR
};
let res = Response::new(status_code);
Ok::<http::Response<Full<Bytes>>, Infallible>(res.into())
}))
.layer(TimeoutLayer::new(Duration::from_secs(10))),
);
let addr = listener

View File

@ -18,7 +18,7 @@ pub(crate) async fn websocket_handler(
let source = subscription.source();
let topic = subscription.topic();
while let Some(Ok(Message::Text(message))) = socket.recv().await {
let data = Box::leak(message.into_boxed_str());
let data = message.leak();
match serde_json::from_str::<'_, CloudEvent>(data) {
Ok(event) => {
let event_session_id = event.session_id();

View File

@ -3,6 +3,7 @@
#![feature(async_fn_in_trait)]
#![feature(once_cell)]
#![feature(result_option_inspect)]
#![feature(string_leak)]
mod channel;
mod cluster;

View File

@ -1,4 +1,3 @@
use crate::AxumExtractor;
use axum::{
body::{Body, BoxBody},
http::{Request, Response, StatusCode},
@ -10,7 +9,7 @@ pub(crate) async fn request_context(
req: Request<Body>,
next: Next<Body>,
) -> Result<Response<BoxBody>, StatusCode> {
let mut req_extractor = AxumExtractor(req);
let mut req_extractor = crate::AxumExtractor(req);
let ext = match req_extractor.get_context() {
Some(_) => None,
None => {

View File

@ -22,7 +22,7 @@ pub(crate) static TRACING_MIDDLEWARE: LazyLock<
let mut env_filter = if is_dev {
"sqlx=trace,tower_http=trace,zino=trace,zino_core=trace"
} else {
"warn,tower_http=info,zino=info,zino_core=info"
"sqlx=warn,tower_http=info,zino=info,zino_core=info"
};
let mut display_target = is_dev;
let mut display_filename = false;

View File

@ -16,7 +16,7 @@ use toml::value::Table;
use zino_core::{CloudEvent, Context, Map, Rejection, RequestContext, State, Validation};
/// An HTTP request extractor for `axum`.
pub struct AxumExtractor<T>(pub T);
pub struct AxumExtractor<T>(pub(crate) T);
impl<T> Deref for AxumExtractor<T> {
type Target = T;
@ -70,7 +70,7 @@ impl RequestContext for AxumExtractor<Request<Body>> {
async fn parse_body(&mut self) -> Result<Map, Validation> {
let form_urlencoded = self
.get_header("Content-Type")
.get_header("content-type")
.unwrap_or("application/x-www-form-urlencoded")
.starts_with("application/x-www-form-urlencoded");
let body = self.body_mut();
@ -109,14 +109,7 @@ impl FromRequest<(), Body> for AxumExtractor<Request<Body>> {
type Rejection = Infallible;
async fn from_request(req: Request<Body>, _state: &()) -> Result<Self, Self::Rejection> {
let (parts, body) = req.into_parts();
let mut request = Request::new(body);
*request.method_mut() = parts.method;
*request.uri_mut() = parts.uri;
*request.version_mut() = parts.version;
*request.headers_mut() = parts.headers;
*request.extensions_mut() = parts.extensions;
Ok(AxumExtractor(request))
Ok(AxumExtractor(req))
}
}