Implement `IntoResponse` for `MultipartError` (#1861)

This commit is contained in:
David Pedersen 2023-03-21 09:24:06 +01:00 committed by GitHub
parent 8e1eb8979f
commit 03e8bc77f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 165 additions and 5 deletions

View File

@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning].
# Unreleased
- None.
- **added:** Implement `IntoResponse` for `MultipartError` ([#1861])
# 0.7.1 (13. March, 2023)

View File

@ -39,6 +39,7 @@ axum = { path = "../axum", version = "0.6.9", default-features = false }
bytes = "1.1.0"
futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
http = "0.2"
http-body = "0.4.4"
mime = "0.3"
pin-project-lite = "0.2"
tokio = "1.19"

View File

@ -12,9 +12,10 @@ use axum::{
use futures_util::stream::Stream;
use http::{
header::{HeaderMap, CONTENT_TYPE},
Request,
Request, StatusCode,
};
use std::{
error::Error,
fmt,
pin::Pin,
task::{Context, Poll},
@ -246,6 +247,57 @@ impl MultipartError {
fn from_multer(multer: multer::Error) -> Self {
Self { source: multer }
}
/// Get the response body text used for this rejection.
pub fn body_text(&self) -> String {
self.source.to_string()
}
/// Get the status code used for this rejection.
pub fn status(&self) -> http::StatusCode {
status_code_from_multer_error(&self.source)
}
}
fn status_code_from_multer_error(err: &multer::Error) -> StatusCode {
match err {
multer::Error::UnknownField { .. }
| multer::Error::IncompleteFieldData { .. }
| multer::Error::IncompleteHeaders
| multer::Error::ReadHeaderFailed(..)
| multer::Error::DecodeHeaderName { .. }
| multer::Error::DecodeContentType(..)
| multer::Error::NoBoundary
| multer::Error::DecodeHeaderValue { .. }
| multer::Error::NoMultipart
| multer::Error::IncompleteStream => StatusCode::BAD_REQUEST,
multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => {
StatusCode::PAYLOAD_TOO_LARGE
}
multer::Error::StreamReadFailed(err) => {
if let Some(err) = err.downcast_ref::<multer::Error>() {
return status_code_from_multer_error(err);
}
if err
.downcast_ref::<axum::Error>()
.and_then(|err| err.source())
.and_then(|err| err.downcast_ref::<http_body::LengthLimitError>())
.is_some()
{
return StatusCode::PAYLOAD_TOO_LARGE;
}
StatusCode::INTERNAL_SERVER_ERROR
}
_ => StatusCode::INTERNAL_SERVER_ERROR,
}
}
impl IntoResponse for MultipartError {
fn into_response(self) -> Response {
(self.status(), self.body_text()).into_response()
}
}
impl fmt::Display for MultipartError {
@ -357,7 +409,9 @@ impl std::error::Error for InvalidBoundary {}
mod tests {
use super::*;
use crate::test_helpers::*;
use axum::{body::Body, response::IntoResponse, routing::post, Router};
use axum::{
body::Body, extract::DefaultBodyLimit, response::IntoResponse, routing::post, Router,
};
#[tokio::test]
async fn content_type_with_encoding() {
@ -395,4 +449,28 @@ mod tests {
async fn handler(_: Multipart) {}
let _app: Router<(), http_body::Limited<Body>> = Router::new().route("/", post(handler));
}
#[tokio::test]
async fn body_too_large() {
const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();
async fn handle(mut multipart: Multipart) -> Result<(), MultipartError> {
while let Some(field) = multipart.next_field().await? {
field.bytes().await?;
}
Ok(())
}
let app = Router::new()
.route("/", post(handle))
.layer(DefaultBodyLimit::max(BYTES.len() - 1));
let client = TestClient::new(app);
let form =
reqwest::multipart::Form::new().part("file", reqwest::multipart::Part::bytes(BYTES));
let res = client.post("/").multipart(form).send().await;
assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
}
}

View File

@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
- None.
- **added:** Implement `IntoResponse` for `MultipartError` ([#1861])
[#1861]: https://github.com/tokio-rs/axum/pull/1861
# 0.6.11 (13. March, 2023)

View File

@ -6,10 +6,12 @@ use super::{BodyStream, FromRequest};
use crate::body::{Bytes, HttpBody};
use crate::BoxError;
use async_trait::async_trait;
use axum_core::response::{IntoResponse, Response};
use axum_core::RequestExt;
use futures_util::stream::Stream;
use http::header::{HeaderMap, CONTENT_TYPE};
use http::Request;
use http::{Request, StatusCode};
use std::error::Error;
use std::{
fmt,
pin::Pin,
@ -209,6 +211,51 @@ impl MultipartError {
fn from_multer(multer: multer::Error) -> Self {
Self { source: multer }
}
/// Get the response body text used for this rejection.
pub fn body_text(&self) -> String {
self.source.to_string()
}
/// Get the status code used for this rejection.
pub fn status(&self) -> http::StatusCode {
status_code_from_multer_error(&self.source)
}
}
fn status_code_from_multer_error(err: &multer::Error) -> StatusCode {
match err {
multer::Error::UnknownField { .. }
| multer::Error::IncompleteFieldData { .. }
| multer::Error::IncompleteHeaders
| multer::Error::ReadHeaderFailed(..)
| multer::Error::DecodeHeaderName { .. }
| multer::Error::DecodeContentType(..)
| multer::Error::NoBoundary
| multer::Error::DecodeHeaderValue { .. }
| multer::Error::NoMultipart
| multer::Error::IncompleteStream => StatusCode::BAD_REQUEST,
multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => {
StatusCode::PAYLOAD_TOO_LARGE
}
multer::Error::StreamReadFailed(err) => {
if let Some(err) = err.downcast_ref::<multer::Error>() {
return status_code_from_multer_error(err);
}
if err
.downcast_ref::<crate::Error>()
.and_then(|err| err.source())
.and_then(|err| err.downcast_ref::<http_body::LengthLimitError>())
.is_some()
{
return StatusCode::PAYLOAD_TOO_LARGE;
}
StatusCode::INTERNAL_SERVER_ERROR
}
_ => StatusCode::INTERNAL_SERVER_ERROR,
}
}
impl fmt::Display for MultipartError {
@ -223,6 +270,12 @@ impl std::error::Error for MultipartError {
}
}
impl IntoResponse for MultipartError {
fn into_response(self) -> Response {
(self.status(), self.body_text()).into_response()
}
}
fn parse_boundary(headers: &HeaderMap) -> Option<String> {
let content_type = headers.get(CONTENT_TYPE)?.to_str().ok()?;
multer::parse_boundary(content_type).ok()
@ -247,6 +300,8 @@ define_rejection! {
#[cfg(test)]
mod tests {
use axum_core::extract::DefaultBodyLimit;
use super::*;
use crate::{body::Body, response::IntoResponse, routing::post, test_helpers::*, Router};
@ -286,4 +341,28 @@ mod tests {
async fn handler(_: Multipart) {}
let _app: Router<(), http_body::Limited<Body>> = Router::new().route("/", post(handler));
}
#[crate::test]
async fn body_too_large() {
const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();
async fn handle(mut multipart: Multipart) -> Result<(), MultipartError> {
while let Some(field) = multipart.next_field().await? {
field.bytes().await?;
}
Ok(())
}
let app = Router::new()
.route("/", post(handle))
.layer(DefaultBodyLimit::max(BYTES.len() - 1));
let client = TestClient::new(app);
let form =
reqwest::multipart::Form::new().part("file", reqwest::multipart::Part::bytes(BYTES));
let res = client.post("/").multipart(form).send().await;
assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
}
}