mirror of https://github.com/smithy-lang/smithy-rs
disable stalled stream protection on empty bodies and after read complete (#3644)
## Motivation and Context <!--- Why is this change required? What problem does it solve? --> <!--- If it fixes an open issue, please link to the issue here --> * https://github.com/awslabs/aws-sdk-rust/issues/1141 * https://github.com/awslabs/aws-sdk-rust/issues/1146 * https://github.com/awslabs/aws-sdk-rust/issues/1148 ## Description * Disables stalled stream upload protection for requests with an empty/zero length body. * Disables stalled stream upload throughput checking once the request body has been read and handed off to the HTTP layer. ## Testing Additional integration tests added covering empty bodies and completed uploads. Tested SQS issue against latest runtime and can see it works now. The S3 `CopyObject` issue is related to downloads and will need a different solution. ## Checklist <!--- If a checkbox below is not applicable, then please DELETE it rather than leaving it unchecked --> - [x] I have updated `CHANGELOG.next.toml` if I made changes to the smithy-rs codegen or runtime crates - [x] I have updated `CHANGELOG.next.toml` if I made changes to the AWS SDK, generated SDK code, or SDK runtime crates ---- _By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice._ --------- Co-authored-by: Zelda Hessler <zhessler@amazon.com> Co-authored-by: ysaito1001 <awsaito@amazon.com>
This commit is contained in:
parent
d755bd2cd9
commit
f0ddc666d0
|
@ -36,6 +36,18 @@ references = ["aws-sdk-rust#1079"]
|
|||
meta = { "breaking" = false, "bug" = true, "tada" = false }
|
||||
author = "rcoh"
|
||||
|
||||
[[aws-sdk-rust]]
|
||||
message = "Fixes stalled upload stream protection to not apply to empty request bodies and to stop checking for violations once the request body has been read."
|
||||
references = ["aws-sdk-rust#1141", "aws-sdk-rust#1146", "aws-sdk-rust#1148"]
|
||||
meta = { "breaking" = false, "tada" = false, "bug" = true }
|
||||
authors = ["aajtodd", "Velfi"]
|
||||
|
||||
[[smithy-rs]]
|
||||
message = "Fixes stalled upload stream protection to not apply to empty request bodies and to stop checking for violations once the request body has been read."
|
||||
references = ["aws-sdk-rust#1141", "aws-sdk-rust#1146", "aws-sdk-rust#1148"]
|
||||
meta = { "breaking" = false, "tada" = false, "bug" = true }
|
||||
authors = ["aajtodd", "Velfi"]
|
||||
|
||||
[[aws-sdk-rust]]
|
||||
message = "Updating the documentation for the `app_name` method on `ConfigLoader` to indicate the order of precedence for the sources of the `AppName`."
|
||||
references = ["smithy-rs#3645"]
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "aws-smithy-runtime-api"
|
||||
version = "1.6.0"
|
||||
version = "1.6.1"
|
||||
authors = ["AWS Rust SDK Team <aws-sdk-rust@amazon.com>", "Zelda Hessler <zhessler@amazon.com>"]
|
||||
description = "Smithy runtime types."
|
||||
edition = "2021"
|
||||
|
|
|
@ -13,7 +13,11 @@
|
|||
use aws_smithy_types::config_bag::{Storable, StoreReplace};
|
||||
use std::time::Duration;
|
||||
|
||||
const DEFAULT_GRACE_PERIOD: Duration = Duration::from_secs(5);
|
||||
/// The default grace period for stalled stream protection.
|
||||
///
|
||||
/// When a stream stalls for longer than this grace period, the stream will
|
||||
/// return an error.
|
||||
pub const DEFAULT_GRACE_PERIOD: Duration = Duration::from_secs(20);
|
||||
|
||||
/// Configuration for stalled stream protection.
|
||||
///
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "aws-smithy-runtime"
|
||||
version = "1.5.2"
|
||||
version = "1.5.3"
|
||||
authors = ["AWS Rust SDK Team <aws-sdk-rust@amazon.com>", "Zelda Hessler <zhessler@amazon.com>"]
|
||||
description = "The new smithy runtime crate"
|
||||
edition = "2021"
|
||||
|
|
|
@ -136,6 +136,10 @@ impl UploadThroughput {
|
|||
self.logs.lock().unwrap().push_bytes_transferred(now, bytes);
|
||||
}
|
||||
|
||||
pub(crate) fn mark_complete(&self) -> bool {
|
||||
self.logs.lock().unwrap().mark_complete()
|
||||
}
|
||||
|
||||
pub(crate) fn report(&self, now: SystemTime) -> ThroughputReport {
|
||||
self.logs.lock().unwrap().report(now)
|
||||
}
|
||||
|
@ -177,6 +181,8 @@ trait UploadReport {
|
|||
impl UploadReport for ThroughputReport {
|
||||
fn minimum_throughput_violated(self, minimum_throughput: Throughput) -> (bool, Throughput) {
|
||||
let throughput = match self {
|
||||
// stream has been exhausted, stop tracking violations
|
||||
ThroughputReport::Complete => return (false, ZERO_THROUGHPUT),
|
||||
// If the report is incomplete, then we don't have enough data yet to
|
||||
// decide if minimum throughput was violated.
|
||||
ThroughputReport::Incomplete => {
|
||||
|
|
|
@ -22,6 +22,7 @@ trait DownloadReport {
|
|||
impl DownloadReport for ThroughputReport {
|
||||
fn minimum_throughput_violated(self, minimum_throughput: Throughput) -> (bool, Throughput) {
|
||||
let throughput = match self {
|
||||
ThroughputReport::Complete => return (false, ZERO_THROUGHPUT),
|
||||
// If the report is incomplete, then we don't have enough data yet to
|
||||
// decide if minimum throughput was violated.
|
||||
ThroughputReport::Incomplete => {
|
||||
|
@ -175,6 +176,18 @@ where
|
|||
tracing::trace!("received data: {}", bytes.len());
|
||||
this.throughput
|
||||
.push_bytes_transferred(now, bytes.len() as u64);
|
||||
|
||||
// hyper will optimistically stop polling when end of stream is reported
|
||||
// (e.g. when content-length amount of data has been consumed) which means
|
||||
// we may never get to `Poll:Ready(None)`. Check for same condition and
|
||||
// attempt to stop checking throughput violations _now_ as we may never
|
||||
// get polled again. The caveat here is that it depends on `Body` implementations
|
||||
// implementing `is_end_stream()` correctly. Users can also disable SSP as an
|
||||
// alternative for such fringe use cases.
|
||||
if self.is_end_stream() {
|
||||
tracing::trace!("stream reported end of stream before Poll::Ready(None) reached; marking stream complete");
|
||||
self.throughput.mark_complete();
|
||||
}
|
||||
Poll::Ready(Some(Ok(bytes)))
|
||||
}
|
||||
Poll::Pending => {
|
||||
|
@ -183,7 +196,12 @@ where
|
|||
Poll::Pending
|
||||
}
|
||||
// If we've read all the data or an error occurred, then return that result.
|
||||
res => res,
|
||||
res => {
|
||||
if this.throughput.mark_complete() {
|
||||
tracing::trace!("stream completed: {:?}", res);
|
||||
}
|
||||
res
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -4,10 +4,12 @@
|
|||
*/
|
||||
|
||||
use super::Throughput;
|
||||
use aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig;
|
||||
use aws_smithy_runtime_api::client::stalled_stream_protection::{
|
||||
StalledStreamProtectionConfig, DEFAULT_GRACE_PERIOD,
|
||||
};
|
||||
use std::time::Duration;
|
||||
|
||||
/// A collection of options for configuring a [`MinimumThroughputBody`](super::MinimumThroughputBody).
|
||||
/// A collection of options for configuring a [`MinimumThroughputBody`](super::MinimumThroughputDownloadBody).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MinimumThroughputBodyOptions {
|
||||
/// The minimum throughput that is acceptable.
|
||||
|
@ -69,6 +71,13 @@ impl MinimumThroughputBodyOptions {
|
|||
}
|
||||
}
|
||||
|
||||
const DEFAULT_MINIMUM_THROUGHPUT: Throughput = Throughput {
|
||||
bytes_read: 1,
|
||||
per_time_elapsed: Duration::from_secs(1),
|
||||
};
|
||||
|
||||
const DEFAULT_CHECK_WINDOW: Duration = Duration::from_secs(1);
|
||||
|
||||
impl Default for MinimumThroughputBodyOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
|
@ -87,14 +96,6 @@ pub struct MinimumThroughputBodyOptionsBuilder {
|
|||
grace_period: Option<Duration>,
|
||||
}
|
||||
|
||||
const DEFAULT_GRACE_PERIOD: Duration = Duration::from_secs(0);
|
||||
const DEFAULT_MINIMUM_THROUGHPUT: Throughput = Throughput {
|
||||
bytes_read: 1,
|
||||
per_time_elapsed: Duration::from_secs(1),
|
||||
};
|
||||
|
||||
const DEFAULT_CHECK_WINDOW: Duration = Duration::from_secs(1);
|
||||
|
||||
impl MinimumThroughputBodyOptionsBuilder {
|
||||
/// Create a new `MinimumThroughputBodyOptionsBuilder`.
|
||||
pub fn new() -> Self {
|
||||
|
|
|
@ -260,6 +260,8 @@ pub(crate) enum ThroughputReport {
|
|||
Pending,
|
||||
/// The stream transferred this amount of throughput during the time window.
|
||||
Transferred(Throughput),
|
||||
/// The stream has completed, no more data is expected.
|
||||
Complete,
|
||||
}
|
||||
|
||||
const BIN_COUNT: usize = 10;
|
||||
|
@ -285,6 +287,7 @@ pub(super) struct ThroughputLogs {
|
|||
resolution: Duration,
|
||||
current_tail: SystemTime,
|
||||
buffer: LogBuffer<BIN_COUNT>,
|
||||
stream_complete: bool,
|
||||
}
|
||||
|
||||
impl ThroughputLogs {
|
||||
|
@ -302,6 +305,7 @@ impl ThroughputLogs {
|
|||
resolution,
|
||||
current_tail: now,
|
||||
buffer: LogBuffer::new(),
|
||||
stream_complete: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -343,8 +347,24 @@ impl ThroughputLogs {
|
|||
assert!(self.current_tail >= now);
|
||||
}
|
||||
|
||||
/// Mark the stream complete indicating no more data is expected. This is an
|
||||
/// idempotent operation -- subsequent invocations of this function have no effect
|
||||
/// and return false.
|
||||
///
|
||||
/// After marking a stream complete [report](#method.report) will forever more return
|
||||
/// [ThroughputReport::Complete]
|
||||
pub(super) fn mark_complete(&mut self) -> bool {
|
||||
let prev = self.stream_complete;
|
||||
self.stream_complete = true;
|
||||
!prev
|
||||
}
|
||||
|
||||
/// Generates an overall report of the time window.
|
||||
pub(super) fn report(&mut self, now: SystemTime) -> ThroughputReport {
|
||||
if self.stream_complete {
|
||||
return ThroughputReport::Complete;
|
||||
}
|
||||
|
||||
self.catch_up(now);
|
||||
self.buffer.fill_gaps();
|
||||
|
||||
|
|
|
@ -65,6 +65,12 @@ impl Intercept for StalledStreamProtectionInterceptor {
|
|||
) -> Result<(), BoxError> {
|
||||
if let Some(sspcfg) = cfg.load::<StalledStreamProtectionConfig>().cloned() {
|
||||
if sspcfg.upload_enabled() {
|
||||
if let Some(0) = context.request().body().content_length() {
|
||||
tracing::trace!(
|
||||
"skipping stalled stream protection for zero length request body"
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
let (_async_sleep, time_source) = get_runtime_component_deps(runtime_components)?;
|
||||
let now = time_source.now();
|
||||
|
||||
|
|
|
@ -105,9 +105,13 @@ async fn download_stalls() {
|
|||
let (time, sleep) = tick_advance_time_and_sleep();
|
||||
let (server, response_sender) = channel_server();
|
||||
let op = operation(server, time.clone(), sleep);
|
||||
let barrier = Arc::new(Barrier::new(2));
|
||||
|
||||
let c = barrier.clone();
|
||||
let server = tokio::spawn(async move {
|
||||
for _ in 1..10 {
|
||||
c.wait().await;
|
||||
for i in 1..10 {
|
||||
tracing::debug!("send {i}");
|
||||
response_sender.send(NEAT_DATA).await.unwrap();
|
||||
tick!(time, Duration::from_secs(1));
|
||||
}
|
||||
|
@ -115,7 +119,10 @@ async fn download_stalls() {
|
|||
});
|
||||
|
||||
let response_body = op.invoke(()).await.expect("initial success");
|
||||
let result = tokio::spawn(eagerly_consume(response_body));
|
||||
let result = tokio::spawn(async move {
|
||||
barrier.wait().await;
|
||||
eagerly_consume(response_body).await
|
||||
});
|
||||
server.await.unwrap();
|
||||
|
||||
let err = result
|
||||
|
@ -188,6 +195,7 @@ async fn user_downloads_data_too_slowly() {
|
|||
}
|
||||
|
||||
use download_test_tools::*;
|
||||
use tokio::sync::Barrier;
|
||||
mod download_test_tools {
|
||||
use crate::stalled_stream_common::*;
|
||||
|
||||
|
|
|
@ -7,6 +7,8 @@
|
|||
|
||||
#[macro_use]
|
||||
mod stalled_stream_common;
|
||||
|
||||
use aws_smithy_runtime_api::client::stalled_stream_protection::DEFAULT_GRACE_PERIOD;
|
||||
use stalled_stream_common::*;
|
||||
|
||||
/// Scenario: Successful upload at a rate above the minimum throughput.
|
||||
|
@ -88,7 +90,7 @@ async fn upload_too_slow() {
|
|||
async fn upload_stalls() {
|
||||
let _logs = show_test_logs();
|
||||
|
||||
let (server, time, sleep) = stalling_server();
|
||||
let (server, time, sleep) = stalling_server(None);
|
||||
let op = operation(server, time.clone(), sleep);
|
||||
|
||||
let (body, body_sender) = channel_body();
|
||||
|
@ -107,27 +109,84 @@ async fn upload_stalls() {
|
|||
expect_timeout(result.await.expect("no panics"));
|
||||
}
|
||||
|
||||
/// Scenario: All the request data is either uploaded to the server or buffered in the
|
||||
/// HTTP client, but the response doesn't start coming through within the grace period.
|
||||
/// Expected: MUST timeout after the grace period completes.
|
||||
/// Scenario: Request does not have a body. Server response doesn't start coming through
|
||||
/// until after the grace period.
|
||||
/// Expected: MUST NOT timeout.
|
||||
#[tokio::test]
|
||||
async fn complete_upload_no_response() {
|
||||
async fn empty_request_body_delayed_response() {
|
||||
let _logs = show_test_logs();
|
||||
|
||||
let (server, time, sleep) = stalling_server();
|
||||
let (server, time, sleep) = stalling_server(Some(Duration::from_secs(6)));
|
||||
let op = operation(server, time.clone(), sleep);
|
||||
|
||||
let result = tokio::spawn(async move { op.invoke(SdkBody::empty()).await });
|
||||
|
||||
let _advance = tokio::spawn(async move {
|
||||
for _ in 0..6 {
|
||||
tick!(time, Duration::from_secs(1));
|
||||
}
|
||||
});
|
||||
|
||||
assert_eq!(200, result.await.unwrap().expect("success").as_u16());
|
||||
}
|
||||
|
||||
/// Scenario: All the request data is either uploaded to the server or buffered in the
|
||||
/// HTTP client, but the response doesn't start coming through within the grace period.
|
||||
/// Expected: MUST NOT timeout, upload throughput should only apply up until the request body has
|
||||
/// been read completely and handed off to the HTTP client.
|
||||
#[tokio::test]
|
||||
async fn complete_upload_delayed_response() {
|
||||
let _logs = show_test_logs();
|
||||
|
||||
let (server, time, sleep) = stalling_server(Some(Duration::from_secs(6)));
|
||||
let op = operation(server, time.clone(), sleep);
|
||||
|
||||
let (body, body_sender) = channel_body();
|
||||
let result = tokio::spawn(async move { op.invoke(body).await });
|
||||
|
||||
let _streamer = tokio::spawn(async move {
|
||||
info!("send data");
|
||||
body_sender.send(NEAT_DATA).await.unwrap();
|
||||
tick!(time, Duration::from_secs(1));
|
||||
info!("body send complete; dropping");
|
||||
drop(body_sender);
|
||||
time.tick(Duration::from_secs(6)).await;
|
||||
tick!(time, DEFAULT_GRACE_PERIOD);
|
||||
info!("body stream task complete");
|
||||
// advance to unblock the stalled server
|
||||
tick!(time, Duration::from_secs(2));
|
||||
});
|
||||
|
||||
expect_timeout(result.await.expect("no panics"));
|
||||
assert_eq!(200, result.await.unwrap().expect("success").as_u16());
|
||||
}
|
||||
|
||||
/// Scenario: Upload all request data and never poll again once content-length has
|
||||
/// been reached. Hyper will stop polling once it detects end of stream so we can't rely
|
||||
/// on reaching `Poll:Ready(None)` to detect end of stream.
|
||||
///
|
||||
/// ref: https://github.com/hyperium/hyper/issues/1545
|
||||
/// ref: https://github.com/hyperium/hyper/issues/1521
|
||||
///
|
||||
/// Expected: MUST NOT timeout, upload throughput should only apply up until the request body has
|
||||
/// been read completely. Once no more data is expected we should stop checking for throughput
|
||||
/// violations.
|
||||
#[tokio::test]
|
||||
async fn complete_upload_stop_polling() {
|
||||
let _logs = show_test_logs();
|
||||
|
||||
let (server, time, sleep) = limited_read_server(NEAT_DATA.len(), Some(Duration::from_secs(7)));
|
||||
let op = operation(server, time.clone(), sleep.clone());
|
||||
|
||||
let body = SdkBody::from(NEAT_DATA);
|
||||
let result = tokio::spawn(async move { op.invoke(body).await });
|
||||
|
||||
tokio::spawn(async move {
|
||||
// advance past the grace period
|
||||
tick!(time, DEFAULT_GRACE_PERIOD + Duration::from_secs(1));
|
||||
// unblock server
|
||||
tick!(time, Duration::from_secs(2));
|
||||
});
|
||||
|
||||
assert_eq!(200, result.await.unwrap().expect("success").as_u16());
|
||||
}
|
||||
|
||||
// Scenario: The server stops asking for data, the client maxes out its send buffer,
|
||||
|
@ -189,6 +248,8 @@ async fn user_provides_data_too_slowly() {
|
|||
|
||||
use upload_test_tools::*;
|
||||
mod upload_test_tools {
|
||||
use aws_smithy_async::rt::sleep::AsyncSleep;
|
||||
|
||||
use crate::stalled_stream_common::*;
|
||||
|
||||
pub fn successful_response() -> HttpResponse {
|
||||
|
@ -285,24 +346,43 @@ mod upload_test_tools {
|
|||
fake_server!(FakeServerConnector, fake_server, bool, advance_time)
|
||||
}
|
||||
|
||||
/// Fake server/connector that reads some data, and then stalls.
|
||||
pub fn stalling_server() -> (SharedHttpConnector, TickAdvanceTime, TickAdvanceSleep) {
|
||||
/// Fake server/connector that reads some data, and then stalls for the given time before
|
||||
/// returning a response. If `None` is given the server will stall indefinitely.
|
||||
pub fn stalling_server(
|
||||
respond_after: Option<Duration>,
|
||||
) -> (SharedHttpConnector, TickAdvanceTime, TickAdvanceSleep) {
|
||||
async fn fake_server(
|
||||
mut body: Pin<&mut SdkBody>,
|
||||
_time: TickAdvanceTime,
|
||||
_sleep: TickAdvanceSleep,
|
||||
_: (),
|
||||
sleep: TickAdvanceSleep,
|
||||
respond_after: Option<Duration>,
|
||||
) -> HttpResponse {
|
||||
let mut times = 5;
|
||||
while times > 0 && poll_fn(|cx| body.as_mut().poll_data(cx)).await.is_some() {
|
||||
times -= 1;
|
||||
}
|
||||
// never awake after this
|
||||
tracing::info!("stalling indefinitely");
|
||||
std::future::pending::<()>().await;
|
||||
unreachable!()
|
||||
|
||||
match respond_after {
|
||||
Some(delay) => {
|
||||
tracing::info!("stalling for {} seconds", delay.as_secs());
|
||||
sleep.sleep(delay).await;
|
||||
tracing::info!("returning delayed response");
|
||||
successful_response()
|
||||
}
|
||||
None => {
|
||||
// never awake after this
|
||||
tracing::info!("stalling indefinitely");
|
||||
std::future::pending::<()>().await;
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
}
|
||||
fake_server!(FakeServerConnector, fake_server)
|
||||
fake_server!(
|
||||
FakeServerConnector,
|
||||
fake_server,
|
||||
Option<Duration>,
|
||||
respond_after
|
||||
)
|
||||
}
|
||||
|
||||
/// Fake server/connector that polls data after each period of time in the given
|
||||
|
@ -332,6 +412,57 @@ mod upload_test_tools {
|
|||
)
|
||||
}
|
||||
|
||||
/// Fake server/connector that polls data only up to the content-length. Optionally delays
|
||||
/// sending the response by the given duration.
|
||||
pub fn limited_read_server(
|
||||
content_len: usize,
|
||||
respond_after: Option<Duration>,
|
||||
) -> (SharedHttpConnector, TickAdvanceTime, TickAdvanceSleep) {
|
||||
async fn fake_server(
|
||||
mut body: Pin<&mut SdkBody>,
|
||||
_time: TickAdvanceTime,
|
||||
sleep: TickAdvanceSleep,
|
||||
params: (usize, Option<Duration>),
|
||||
) -> HttpResponse {
|
||||
let mut remaining = params.0;
|
||||
loop {
|
||||
match poll_fn(|cx| body.as_mut().poll_data(cx)).await {
|
||||
Some(res) => {
|
||||
let rc = res.unwrap().len();
|
||||
remaining -= rc;
|
||||
tracing::info!("read {rc} bytes; remaining: {remaining}");
|
||||
if remaining == 0 {
|
||||
tracing::info!("read reported content-length data, stopping polling");
|
||||
break;
|
||||
};
|
||||
}
|
||||
None => {
|
||||
tracing::info!(
|
||||
"read until poll_data() returned None, no data left, stopping polling"
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let respond_after = params.1;
|
||||
if let Some(delay) = respond_after {
|
||||
tracing::info!("stalling for {} seconds", delay.as_secs());
|
||||
sleep.sleep(delay).await;
|
||||
tracing::info!("returning delayed response");
|
||||
}
|
||||
|
||||
successful_response()
|
||||
}
|
||||
|
||||
fake_server!(
|
||||
FakeServerConnector,
|
||||
fake_server,
|
||||
(usize, Option<Duration>),
|
||||
(content_len, respond_after)
|
||||
)
|
||||
}
|
||||
|
||||
pub fn expect_timeout(result: Result<StatusCode, SdkError<Infallible, Response<SdkBody>>>) {
|
||||
let err = result.expect_err("should have timed out");
|
||||
assert_str_contains!(
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "aws-smithy-types"
|
||||
version = "1.1.9"
|
||||
version = "1.1.10"
|
||||
authors = [
|
||||
"AWS Rust SDK Team <aws-sdk-rust@amazon.com>",
|
||||
"Russell Cohen <rcoh@amazon.com>",
|
||||
|
@ -67,6 +67,7 @@ tokio = { version = "1.23.1", features = [
|
|||
"fs",
|
||||
"io-util",
|
||||
] }
|
||||
# This is used in a doctest, don't listen to udeps.
|
||||
tokio-stream = "0.1.5"
|
||||
tempfile = "3.2.0"
|
||||
|
||||
|
|
|
@ -376,10 +376,12 @@ mod test {
|
|||
async fn http_body_consumes_data() {
|
||||
let mut body = SdkBody::from("hello!");
|
||||
let mut body = Pin::new(&mut body);
|
||||
assert!(!body.is_end_stream());
|
||||
let data = body.next().await;
|
||||
assert!(data.is_some());
|
||||
let data = body.next().await;
|
||||
assert!(data.is_none());
|
||||
assert!(body.is_end_stream());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
|
@ -579,11 +579,13 @@ impl Inner {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(all(test, feature = "rt-tokio"))]
|
||||
mod tests {
|
||||
use super::{ByteStream, Inner};
|
||||
use crate::body::SdkBody;
|
||||
use crate::byte_stream::Inner;
|
||||
use bytes::Bytes;
|
||||
use std::io::Write;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_from_string_body() {
|
||||
|
@ -598,10 +600,8 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[cfg(feature = "rt-tokio")]
|
||||
#[tokio::test]
|
||||
async fn bytestream_into_async_read() {
|
||||
use super::ByteStream;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
|
||||
let byte_stream = ByteStream::from_static(b"data 1\ndata 2\ndata 3");
|
||||
|
@ -614,4 +614,44 @@ mod tests {
|
|||
assert_eq!(lines.next_line().await.unwrap(), Some("data 3".to_owned()));
|
||||
assert_eq!(lines.next_line().await.unwrap(), None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn valid_size_hint() {
|
||||
assert_eq!(ByteStream::from_static(b"hello").size_hint().1, Some(5));
|
||||
assert_eq!(ByteStream::from_static(b"").size_hint().1, Some(0));
|
||||
|
||||
let mut f = NamedTempFile::new().unwrap();
|
||||
f.write_all(b"hello").unwrap();
|
||||
let body = ByteStream::from_path(f.path()).await.unwrap();
|
||||
assert_eq!(body.inner.size_hint().1, Some(5));
|
||||
|
||||
let mut f = NamedTempFile::new().unwrap();
|
||||
f.write_all(b"").unwrap();
|
||||
let body = ByteStream::from_path(f.path()).await.unwrap();
|
||||
assert_eq!(body.inner.size_hint().1, Some(0));
|
||||
}
|
||||
|
||||
#[allow(clippy::bool_assert_comparison)]
|
||||
#[tokio::test]
|
||||
async fn valid_eos() {
|
||||
assert_eq!(
|
||||
ByteStream::from_static(b"hello").inner.body.is_end_stream(),
|
||||
false
|
||||
);
|
||||
let mut f = NamedTempFile::new().unwrap();
|
||||
f.write_all(b"hello").unwrap();
|
||||
let body = ByteStream::from_path(f.path()).await.unwrap();
|
||||
assert_eq!(body.inner.body.content_length(), Some(5));
|
||||
assert!(!body.inner.body.is_end_stream());
|
||||
|
||||
assert_eq!(
|
||||
ByteStream::from_static(b"").inner.body.is_end_stream(),
|
||||
true
|
||||
);
|
||||
let mut f = NamedTempFile::new().unwrap();
|
||||
f.write_all(b"").unwrap();
|
||||
let body = ByteStream::from_path(f.path()).await.unwrap();
|
||||
assert_eq!(body.inner.body.content_length(), Some(0));
|
||||
assert!(body.inner.body.is_end_stream());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,7 +44,10 @@ impl PathBody {
|
|||
|
||||
fn from_file(file: File, length: u64, buffer_size: usize) -> Self {
|
||||
PathBody {
|
||||
state: State::Loaded(ReaderStream::with_capacity(file.take(length), buffer_size)),
|
||||
state: State::Loaded {
|
||||
stream: ReaderStream::with_capacity(file.take(length), buffer_size),
|
||||
bytes_left: length,
|
||||
},
|
||||
length,
|
||||
buffer_size,
|
||||
// The file used to create this `PathBody` should have already had an offset applied
|
||||
|
@ -230,7 +233,10 @@ impl FsBuilder {
|
|||
enum State {
|
||||
Unloaded(PathBuf),
|
||||
Loading(Pin<Box<dyn Future<Output = io::Result<File>> + Send + Sync + 'static>>),
|
||||
Loaded(ReaderStream<io::Take<File>>),
|
||||
Loaded {
|
||||
stream: ReaderStream<io::Take<File>>,
|
||||
bytes_left: u64,
|
||||
},
|
||||
}
|
||||
|
||||
impl http_body_0_4::Body for PathBody {
|
||||
|
@ -238,7 +244,7 @@ impl http_body_0_4::Body for PathBody {
|
|||
type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
|
||||
|
||||
fn poll_data(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Option<Result<Self::Data, Self::Error>>> {
|
||||
use std::task::Poll;
|
||||
|
@ -260,18 +266,27 @@ impl http_body_0_4::Body for PathBody {
|
|||
State::Loading(ref mut future) => {
|
||||
match futures_core::ready!(Pin::new(future).poll(cx)) {
|
||||
Ok(file) => {
|
||||
self.state = State::Loaded(ReaderStream::with_capacity(
|
||||
file.take(self.length),
|
||||
self.buffer_size,
|
||||
));
|
||||
self.state = State::Loaded {
|
||||
stream: ReaderStream::with_capacity(
|
||||
file.take(self.length),
|
||||
self.buffer_size,
|
||||
),
|
||||
bytes_left: self.length,
|
||||
};
|
||||
}
|
||||
Err(e) => return Poll::Ready(Some(Err(e.into()))),
|
||||
};
|
||||
}
|
||||
State::Loaded(ref mut stream) => {
|
||||
State::Loaded {
|
||||
ref mut stream,
|
||||
ref mut bytes_left,
|
||||
} => {
|
||||
use futures_core::Stream;
|
||||
return match futures_core::ready!(std::pin::Pin::new(stream).poll_next(cx)) {
|
||||
Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes))),
|
||||
return match futures_core::ready!(Pin::new(stream).poll_next(cx)) {
|
||||
Some(Ok(bytes)) => {
|
||||
*bytes_left -= bytes.len() as u64;
|
||||
Poll::Ready(Some(Ok(bytes)))
|
||||
}
|
||||
None => Poll::Ready(None),
|
||||
Some(Err(e)) => Poll::Ready(Some(Err(e.into()))),
|
||||
};
|
||||
|
@ -281,15 +296,17 @@ impl http_body_0_4::Body for PathBody {
|
|||
}
|
||||
|
||||
fn poll_trailers(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<Option<http::HeaderMap>, Self::Error>> {
|
||||
std::task::Poll::Ready(Ok(None))
|
||||
}
|
||||
|
||||
fn is_end_stream(&self) -> bool {
|
||||
// fast path end-stream for empty streams
|
||||
self.length == 0
|
||||
match self.state {
|
||||
State::Unloaded(_) | State::Loading(_) => self.length == 0,
|
||||
State::Loaded { bytes_left, .. } => bytes_left == 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> http_body_0_4::SizeHint {
|
||||
|
@ -303,6 +320,7 @@ mod test {
|
|||
use super::FsBuilder;
|
||||
use crate::byte_stream::{ByteStream, Length};
|
||||
use bytes::Buf;
|
||||
use http_body_0_4::Body;
|
||||
use std::io::Write;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
|
@ -370,6 +388,29 @@ mod test {
|
|||
assert_eq!(body.content_length(), Some(1));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fsbuilder_is_end_stream() {
|
||||
let sentence = "A very long sentence that's clearly longer than a single byte.";
|
||||
let mut file = NamedTempFile::new().unwrap();
|
||||
file.write_all(sentence.as_bytes()).unwrap();
|
||||
// Ensure that the file was written to
|
||||
file.flush().expect("flushing is OK");
|
||||
|
||||
let mut body = FsBuilder::new()
|
||||
.path(&file)
|
||||
.build()
|
||||
.await
|
||||
.unwrap()
|
||||
.into_inner();
|
||||
|
||||
assert!(!body.is_end_stream());
|
||||
assert_eq!(body.content_length(), Some(sentence.len() as u64));
|
||||
|
||||
let data = body.data().await.unwrap().unwrap();
|
||||
assert_eq!(data.len(), sentence.len());
|
||||
assert!(body.is_end_stream());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fsbuilder_respects_length() {
|
||||
let mut file = NamedTempFile::new().unwrap();
|
||||
|
|
Loading…
Reference in New Issue