disable stalled stream protection on empty bodies and after read complete (#3644)

## Motivation and Context
<!--- Why is this change required? What problem does it solve? -->
<!--- If it fixes an open issue, please link to the issue here -->
* https://github.com/awslabs/aws-sdk-rust/issues/1141
* https://github.com/awslabs/aws-sdk-rust/issues/1146
* https://github.com/awslabs/aws-sdk-rust/issues/1148


## Description
* Disables stalled stream upload protection for requests with an
empty/zero-length body.
* Disables stalled stream upload throughput checking once the request
body has been fully read and handed off to the HTTP layer.
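
For orientation, a minimal sketch of the user-facing knobs this protection exposes, assuming the public `StalledStreamProtectionConfig` builder in `aws-smithy-runtime-api` (generated SDK clients accept the resulting value through their `stalled_stream_protection` setters). This is illustrative only, not part of the diff below:

```rust
use std::time::Duration;

use aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig;

fn main() {
    // Keep protection enabled but tolerate longer stalls before erroring out.
    let relaxed = StalledStreamProtectionConfig::enabled()
        .grace_period(Duration::from_secs(30))
        .build();
    assert!(relaxed.upload_enabled());

    // Or opt out entirely for workloads that legitimately stall -- the
    // escape hatch referenced in the code comments further down.
    let _disabled = StalledStreamProtectionConfig::disabled();
}
```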


## Testing
Added integration tests covering empty bodies and completed uploads.

Tested the SQS issue against the latest runtime and confirmed it now works. The S3
`CopyObject` issue is related to downloads and will need a different
solution.

## Checklist
<!--- If a checkbox below is not applicable, then please DELETE it
rather than leaving it unchecked -->
- [x] I have updated `CHANGELOG.next.toml` if I made changes to the
smithy-rs codegen or runtime crates
- [x] I have updated `CHANGELOG.next.toml` if I made changes to the AWS
SDK, generated SDK code, or SDK runtime crates

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._

---------

Co-authored-by: Zelda Hessler <zhessler@amazon.com>
Co-authored-by: ysaito1001 <awsaito@amazon.com>
Aaron Todd 2024-05-17 16:17:22 -04:00 committed by GitHub
parent d755bd2cd9
commit f0ddc666d0
15 changed files with 341 additions and 51 deletions

View File

@ -36,6 +36,18 @@ references = ["aws-sdk-rust#1079"]
meta = { "breaking" = false, "bug" = true, "tada" = false }
author = "rcoh"
[[aws-sdk-rust]]
message = "Fixes stalled upload stream protection to not apply to empty request bodies and to stop checking for violations once the request body has been read."
references = ["aws-sdk-rust#1141", "aws-sdk-rust#1146", "aws-sdk-rust#1148"]
meta = { "breaking" = false, "tada" = false, "bug" = true }
authors = ["aajtodd", "Velfi"]
[[smithy-rs]]
message = "Fixes stalled upload stream protection to not apply to empty request bodies and to stop checking for violations once the request body has been read."
references = ["aws-sdk-rust#1141", "aws-sdk-rust#1146", "aws-sdk-rust#1148"]
meta = { "breaking" = false, "tada" = false, "bug" = true }
authors = ["aajtodd", "Velfi"]
[[aws-sdk-rust]]
message = "Updating the documentation for the `app_name` method on `ConfigLoader` to indicate the order of precedence for the sources of the `AppName`."
references = ["smithy-rs#3645"]

View File

@ -1,6 +1,6 @@
[package]
name = "aws-smithy-runtime-api"
version = "1.6.0"
version = "1.6.1"
authors = ["AWS Rust SDK Team <aws-sdk-rust@amazon.com>", "Zelda Hessler <zhessler@amazon.com>"]
description = "Smithy runtime types."
edition = "2021"

View File

@ -13,7 +13,11 @@
use aws_smithy_types::config_bag::{Storable, StoreReplace};
use std::time::Duration;
const DEFAULT_GRACE_PERIOD: Duration = Duration::from_secs(5);
/// The default grace period for stalled stream protection.
///
/// When a stream stalls for longer than this grace period, the stream will
/// return an error.
pub const DEFAULT_GRACE_PERIOD: Duration = Duration::from_secs(20);
/// Configuration for stalled stream protection.
///
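
A small usage sketch of the newly exported constant (the upload integration tests further down import it the same way); illustrative only:

```rust
use std::time::Duration;

use aws_smithy_runtime_api::client::stalled_stream_protection::DEFAULT_GRACE_PERIOD;

fn main() {
    // The default grace period is now public (and raised from 5s to 20s),
    // so callers and tests can reference it instead of hard-coding a value.
    assert_eq!(DEFAULT_GRACE_PERIOD, Duration::from_secs(20));
}
```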

View File

@ -1,6 +1,6 @@
[package]
name = "aws-smithy-runtime"
version = "1.5.2"
version = "1.5.3"
authors = ["AWS Rust SDK Team <aws-sdk-rust@amazon.com>", "Zelda Hessler <zhessler@amazon.com>"]
description = "The new smithy runtime crate"
edition = "2021"

View File

@ -136,6 +136,10 @@ impl UploadThroughput {
self.logs.lock().unwrap().push_bytes_transferred(now, bytes);
}
pub(crate) fn mark_complete(&self) -> bool {
self.logs.lock().unwrap().mark_complete()
}
pub(crate) fn report(&self, now: SystemTime) -> ThroughputReport {
self.logs.lock().unwrap().report(now)
}
@ -177,6 +181,8 @@ trait UploadReport {
impl UploadReport for ThroughputReport {
fn minimum_throughput_violated(self, minimum_throughput: Throughput) -> (bool, Throughput) {
let throughput = match self {
// stream has been exhausted, stop tracking violations
ThroughputReport::Complete => return (false, ZERO_THROUGHPUT),
// If the report is incomplete, then we don't have enough data yet to
// decide if minimum throughput was violated.
ThroughputReport::Incomplete => {

View File

@ -22,6 +22,7 @@ trait DownloadReport {
impl DownloadReport for ThroughputReport {
fn minimum_throughput_violated(self, minimum_throughput: Throughput) -> (bool, Throughput) {
let throughput = match self {
ThroughputReport::Complete => return (false, ZERO_THROUGHPUT),
// If the report is incomplete, then we don't have enough data yet to
// decide if minimum throughput was violated.
ThroughputReport::Incomplete => {
@ -175,6 +176,18 @@ where
tracing::trace!("received data: {}", bytes.len());
this.throughput
.push_bytes_transferred(now, bytes.len() as u64);
// hyper will optimistically stop polling when end of stream is reported
// (e.g. when content-length amount of data has been consumed) which means
// we may never get to `Poll::Ready(None)`. Check for the same condition and
// attempt to stop checking throughput violations _now_ as we may never
// get polled again. The caveat here is that it depends on `Body` implementations
// implementing `is_end_stream()` correctly. Users can also disable SSP as an
// alternative for such fringe use cases.
if self.is_end_stream() {
tracing::trace!("stream reported end of stream before Poll::Ready(None) reached; marking stream complete");
self.throughput.mark_complete();
}
Poll::Ready(Some(Ok(bytes)))
}
Poll::Pending => {
@ -183,7 +196,12 @@ where
Poll::Pending
}
// If we've read all the data or an error occurred, then return that result.
res => res,
res => {
if this.throughput.mark_complete() {
tracing::trace!("stream completed: {:?}", res);
}
res
}
}
}
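
To make the early-completion logic above easier to follow in isolation, here is a hedged, self-contained sketch of the same pattern: a hypothetical `TrackedBody` wrapper over `http-body` 0.4 (not the committed type) that consults `is_end_stream()` after forwarding each chunk and stops throughput tracking immediately, since hyper may never poll again once content-length bytes have been consumed.

```rust
use std::pin::Pin;
use std::task::{Context, Poll};

use bytes::Bytes;
use http::HeaderMap;
use http_body::Body; // http-body 0.4, matching the `http_body_0_4` code above

/// Illustrative wrapper: forwards chunks, but stops throughput tracking as
/// soon as the inner body reports end-of-stream.
#[allow(dead_code)] // sketch only: the real code reports this flag to a throughput log
struct TrackedBody<B> {
    inner: B,
    tracking_complete: bool,
}

impl<B> Body for TrackedBody<B>
where
    B: Body<Data = Bytes> + Unpin,
{
    type Data = Bytes;
    type Error = B::Error;

    fn poll_data(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
        let this = self.get_mut();
        match Pin::new(&mut this.inner).poll_data(cx) {
            Poll::Ready(Some(Ok(bytes))) => {
                // ...record `bytes.len()` against a throughput log here...
                if this.inner.is_end_stream() {
                    // Mirror the early-completion check: don't wait for
                    // `Poll::Ready(None)`, it may never come.
                    this.tracking_complete = true;
                }
                Poll::Ready(Some(Ok(bytes)))
            }
            Poll::Pending => Poll::Pending,
            // End of stream or error: either way, stop tracking.
            res => {
                this.tracking_complete = true;
                res
            }
        }
    }

    fn poll_trailers(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
        Pin::new(&mut self.get_mut().inner).poll_trailers(cx)
    }

    fn is_end_stream(&self) -> bool {
        self.inner.is_end_stream()
    }
}
```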

View File

@ -4,10 +4,12 @@
*/
use super::Throughput;
use aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig;
use aws_smithy_runtime_api::client::stalled_stream_protection::{
StalledStreamProtectionConfig, DEFAULT_GRACE_PERIOD,
};
use std::time::Duration;
/// A collection of options for configuring a [`MinimumThroughputBody`](super::MinimumThroughputBody).
/// A collection of options for configuring a [`MinimumThroughputBody`](super::MinimumThroughputDownloadBody).
#[derive(Debug, Clone)]
pub struct MinimumThroughputBodyOptions {
/// The minimum throughput that is acceptable.
@ -69,6 +71,13 @@ impl MinimumThroughputBodyOptions {
}
}
const DEFAULT_MINIMUM_THROUGHPUT: Throughput = Throughput {
bytes_read: 1,
per_time_elapsed: Duration::from_secs(1),
};
const DEFAULT_CHECK_WINDOW: Duration = Duration::from_secs(1);
impl Default for MinimumThroughputBodyOptions {
fn default() -> Self {
Self {
@ -87,14 +96,6 @@ pub struct MinimumThroughputBodyOptionsBuilder {
grace_period: Option<Duration>,
}
const DEFAULT_GRACE_PERIOD: Duration = Duration::from_secs(0);
const DEFAULT_MINIMUM_THROUGHPUT: Throughput = Throughput {
bytes_read: 1,
per_time_elapsed: Duration::from_secs(1),
};
const DEFAULT_CHECK_WINDOW: Duration = Duration::from_secs(1);
impl MinimumThroughputBodyOptionsBuilder {
/// Create a new `MinimumThroughputBodyOptionsBuilder`.
pub fn new() -> Self {

View File

@ -260,6 +260,8 @@ pub(crate) enum ThroughputReport {
Pending,
/// The stream transferred this amount of throughput during the time window.
Transferred(Throughput),
/// The stream has completed, no more data is expected.
Complete,
}
const BIN_COUNT: usize = 10;
@ -285,6 +287,7 @@ pub(super) struct ThroughputLogs {
resolution: Duration,
current_tail: SystemTime,
buffer: LogBuffer<BIN_COUNT>,
stream_complete: bool,
}
impl ThroughputLogs {
@ -302,6 +305,7 @@ impl ThroughputLogs {
resolution,
current_tail: now,
buffer: LogBuffer::new(),
stream_complete: false,
}
}
@ -343,8 +347,24 @@ impl ThroughputLogs {
assert!(self.current_tail >= now);
}
/// Mark the stream complete indicating no more data is expected. This is an
/// idempotent operation -- subsequent invocations of this function have no effect
/// and return false.
///
/// After marking a stream complete, [report](#method.report) will always return
/// [ThroughputReport::Complete].
pub(super) fn mark_complete(&mut self) -> bool {
let prev = self.stream_complete;
self.stream_complete = true;
!prev
}
/// Generates an overall report of the time window.
pub(super) fn report(&mut self, now: SystemTime) -> ThroughputReport {
if self.stream_complete {
return ThroughputReport::Complete;
}
self.catch_up(now);
self.buffer.fill_gaps();

View File

@ -65,6 +65,12 @@ impl Intercept for StalledStreamProtectionInterceptor {
) -> Result<(), BoxError> {
if let Some(sspcfg) = cfg.load::<StalledStreamProtectionConfig>().cloned() {
if sspcfg.upload_enabled() {
if let Some(0) = context.request().body().content_length() {
tracing::trace!(
"skipping stalled stream protection for zero length request body"
);
return Ok(());
}
let (_async_sleep, time_source) = get_runtime_component_deps(runtime_components)?;
let now = time_source.now();
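
The zero-length check above keys off the body's reported content length; a quick hedged illustration with `SdkBody` from `aws-smithy-types` (assuming its public `content_length()` accessor, which the interceptor itself relies on):

```rust
use aws_smithy_types::body::SdkBody;

fn main() {
    // Empty bodies report an exact length of zero, which is what lets the
    // interceptor skip upload stalled-stream protection for them.
    assert_eq!(SdkBody::empty().content_length(), Some(0));

    // Non-empty in-memory bodies report their exact length as well.
    assert_eq!(SdkBody::from("hello").content_length(), Some(5));
}
```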

View File

@ -105,9 +105,13 @@ async fn download_stalls() {
let (time, sleep) = tick_advance_time_and_sleep();
let (server, response_sender) = channel_server();
let op = operation(server, time.clone(), sleep);
let barrier = Arc::new(Barrier::new(2));
let c = barrier.clone();
let server = tokio::spawn(async move {
for _ in 1..10 {
c.wait().await;
for i in 1..10 {
tracing::debug!("send {i}");
response_sender.send(NEAT_DATA).await.unwrap();
tick!(time, Duration::from_secs(1));
}
@ -115,7 +119,10 @@ async fn download_stalls() {
});
let response_body = op.invoke(()).await.expect("initial success");
let result = tokio::spawn(eagerly_consume(response_body));
let result = tokio::spawn(async move {
barrier.wait().await;
eagerly_consume(response_body).await
});
server.await.unwrap();
let err = result
@ -188,6 +195,7 @@ async fn user_downloads_data_too_slowly() {
}
use download_test_tools::*;
use tokio::sync::Barrier;
mod download_test_tools {
use crate::stalled_stream_common::*;

View File

@ -7,6 +7,8 @@
#[macro_use]
mod stalled_stream_common;
use aws_smithy_runtime_api::client::stalled_stream_protection::DEFAULT_GRACE_PERIOD;
use stalled_stream_common::*;
/// Scenario: Successful upload at a rate above the minimum throughput.
@ -88,7 +90,7 @@ async fn upload_too_slow() {
async fn upload_stalls() {
let _logs = show_test_logs();
let (server, time, sleep) = stalling_server();
let (server, time, sleep) = stalling_server(None);
let op = operation(server, time.clone(), sleep);
let (body, body_sender) = channel_body();
@ -107,27 +109,84 @@ async fn upload_stalls() {
expect_timeout(result.await.expect("no panics"));
}
/// Scenario: All the request data is either uploaded to the server or buffered in the
/// HTTP client, but the response doesn't start coming through within the grace period.
/// Expected: MUST timeout after the grace period completes.
/// Scenario: Request does not have a body. Server response doesn't start coming through
/// until after the grace period.
/// Expected: MUST NOT timeout.
#[tokio::test]
async fn complete_upload_no_response() {
async fn empty_request_body_delayed_response() {
let _logs = show_test_logs();
let (server, time, sleep) = stalling_server();
let (server, time, sleep) = stalling_server(Some(Duration::from_secs(6)));
let op = operation(server, time.clone(), sleep);
let result = tokio::spawn(async move { op.invoke(SdkBody::empty()).await });
let _advance = tokio::spawn(async move {
for _ in 0..6 {
tick!(time, Duration::from_secs(1));
}
});
assert_eq!(200, result.await.unwrap().expect("success").as_u16());
}
/// Scenario: All the request data is either uploaded to the server or buffered in the
/// HTTP client, but the response doesn't start coming through within the grace period.
/// Expected: MUST NOT timeout, upload throughput should only apply up until the request body has
/// been read completely and handed off to the HTTP client.
#[tokio::test]
async fn complete_upload_delayed_response() {
let _logs = show_test_logs();
let (server, time, sleep) = stalling_server(Some(Duration::from_secs(6)));
let op = operation(server, time.clone(), sleep);
let (body, body_sender) = channel_body();
let result = tokio::spawn(async move { op.invoke(body).await });
let _streamer = tokio::spawn(async move {
info!("send data");
body_sender.send(NEAT_DATA).await.unwrap();
tick!(time, Duration::from_secs(1));
info!("body send complete; dropping");
drop(body_sender);
time.tick(Duration::from_secs(6)).await;
tick!(time, DEFAULT_GRACE_PERIOD);
info!("body stream task complete");
// advance to unblock the stalled server
tick!(time, Duration::from_secs(2));
});
expect_timeout(result.await.expect("no panics"));
assert_eq!(200, result.await.unwrap().expect("success").as_u16());
}
/// Scenario: Upload all request data and never poll again once content-length has
/// been reached. Hyper will stop polling once it detects end of stream so we can't rely
/// on reaching `Poll::Ready(None)` to detect end of stream.
///
/// ref: https://github.com/hyperium/hyper/issues/1545
/// ref: https://github.com/hyperium/hyper/issues/1521
///
/// Expected: MUST NOT timeout, upload throughput should only apply up until the request body has
/// been read completely. Once no more data is expected we should stop checking for throughput
/// violations.
#[tokio::test]
async fn complete_upload_stop_polling() {
let _logs = show_test_logs();
let (server, time, sleep) = limited_read_server(NEAT_DATA.len(), Some(Duration::from_secs(7)));
let op = operation(server, time.clone(), sleep.clone());
let body = SdkBody::from(NEAT_DATA);
let result = tokio::spawn(async move { op.invoke(body).await });
tokio::spawn(async move {
// advance past the grace period
tick!(time, DEFAULT_GRACE_PERIOD + Duration::from_secs(1));
// unblock server
tick!(time, Duration::from_secs(2));
});
assert_eq!(200, result.await.unwrap().expect("success").as_u16());
}
// Scenario: The server stops asking for data, the client maxes out its send buffer,
@ -189,6 +248,8 @@ async fn user_provides_data_too_slowly() {
use upload_test_tools::*;
mod upload_test_tools {
use aws_smithy_async::rt::sleep::AsyncSleep;
use crate::stalled_stream_common::*;
pub fn successful_response() -> HttpResponse {
@ -285,24 +346,43 @@ mod upload_test_tools {
fake_server!(FakeServerConnector, fake_server, bool, advance_time)
}
/// Fake server/connector that reads some data, and then stalls.
pub fn stalling_server() -> (SharedHttpConnector, TickAdvanceTime, TickAdvanceSleep) {
/// Fake server/connector that reads some data, and then stalls for the given time before
/// returning a response. If `None` is given the server will stall indefinitely.
pub fn stalling_server(
respond_after: Option<Duration>,
) -> (SharedHttpConnector, TickAdvanceTime, TickAdvanceSleep) {
async fn fake_server(
mut body: Pin<&mut SdkBody>,
_time: TickAdvanceTime,
_sleep: TickAdvanceSleep,
_: (),
sleep: TickAdvanceSleep,
respond_after: Option<Duration>,
) -> HttpResponse {
let mut times = 5;
while times > 0 && poll_fn(|cx| body.as_mut().poll_data(cx)).await.is_some() {
times -= 1;
}
// never awake after this
tracing::info!("stalling indefinitely");
std::future::pending::<()>().await;
unreachable!()
match respond_after {
Some(delay) => {
tracing::info!("stalling for {} seconds", delay.as_secs());
sleep.sleep(delay).await;
tracing::info!("returning delayed response");
successful_response()
}
None => {
// never awake after this
tracing::info!("stalling indefinitely");
std::future::pending::<()>().await;
unreachable!()
}
}
}
fake_server!(FakeServerConnector, fake_server)
fake_server!(
FakeServerConnector,
fake_server,
Option<Duration>,
respond_after
)
}
/// Fake server/connector that polls data after each period of time in the given
@ -332,6 +412,57 @@ mod upload_test_tools {
)
}
/// Fake server/connector that polls data only up to the content-length. Optionally delays
/// sending the response by the given duration.
pub fn limited_read_server(
content_len: usize,
respond_after: Option<Duration>,
) -> (SharedHttpConnector, TickAdvanceTime, TickAdvanceSleep) {
async fn fake_server(
mut body: Pin<&mut SdkBody>,
_time: TickAdvanceTime,
sleep: TickAdvanceSleep,
params: (usize, Option<Duration>),
) -> HttpResponse {
let mut remaining = params.0;
loop {
match poll_fn(|cx| body.as_mut().poll_data(cx)).await {
Some(res) => {
let rc = res.unwrap().len();
remaining -= rc;
tracing::info!("read {rc} bytes; remaining: {remaining}");
if remaining == 0 {
tracing::info!("read reported content-length data, stopping polling");
break;
};
}
None => {
tracing::info!(
"read until poll_data() returned None, no data left, stopping polling"
);
break;
}
}
}
let respond_after = params.1;
if let Some(delay) = respond_after {
tracing::info!("stalling for {} seconds", delay.as_secs());
sleep.sleep(delay).await;
tracing::info!("returning delayed response");
}
successful_response()
}
fake_server!(
FakeServerConnector,
fake_server,
(usize, Option<Duration>),
(content_len, respond_after)
)
}
pub fn expect_timeout(result: Result<StatusCode, SdkError<Infallible, Response<SdkBody>>>) {
let err = result.expect_err("should have timed out");
assert_str_contains!(

View File

@ -1,6 +1,6 @@
[package]
name = "aws-smithy-types"
version = "1.1.9"
version = "1.1.10"
authors = [
"AWS Rust SDK Team <aws-sdk-rust@amazon.com>",
"Russell Cohen <rcoh@amazon.com>",
@ -67,6 +67,7 @@ tokio = { version = "1.23.1", features = [
"fs",
"io-util",
] }
# This is used in a doctest, don't listen to udeps.
tokio-stream = "0.1.5"
tempfile = "3.2.0"

View File

@ -376,10 +376,12 @@ mod test {
async fn http_body_consumes_data() {
let mut body = SdkBody::from("hello!");
let mut body = Pin::new(&mut body);
assert!(!body.is_end_stream());
let data = body.next().await;
assert!(data.is_some());
let data = body.next().await;
assert!(data.is_none());
assert!(body.is_end_stream());
}
#[tokio::test]

View File

@ -579,11 +579,13 @@ impl Inner {
}
}
#[cfg(test)]
#[cfg(all(test, feature = "rt-tokio"))]
mod tests {
use super::{ByteStream, Inner};
use crate::body::SdkBody;
use crate::byte_stream::Inner;
use bytes::Bytes;
use std::io::Write;
use tempfile::NamedTempFile;
#[tokio::test]
async fn read_from_string_body() {
@ -598,10 +600,8 @@ mod tests {
);
}
#[cfg(feature = "rt-tokio")]
#[tokio::test]
async fn bytestream_into_async_read() {
use super::ByteStream;
use tokio::io::AsyncBufReadExt;
let byte_stream = ByteStream::from_static(b"data 1\ndata 2\ndata 3");
@ -614,4 +614,44 @@ mod tests {
assert_eq!(lines.next_line().await.unwrap(), Some("data 3".to_owned()));
assert_eq!(lines.next_line().await.unwrap(), None);
}
#[tokio::test]
async fn valid_size_hint() {
assert_eq!(ByteStream::from_static(b"hello").size_hint().1, Some(5));
assert_eq!(ByteStream::from_static(b"").size_hint().1, Some(0));
let mut f = NamedTempFile::new().unwrap();
f.write_all(b"hello").unwrap();
let body = ByteStream::from_path(f.path()).await.unwrap();
assert_eq!(body.inner.size_hint().1, Some(5));
let mut f = NamedTempFile::new().unwrap();
f.write_all(b"").unwrap();
let body = ByteStream::from_path(f.path()).await.unwrap();
assert_eq!(body.inner.size_hint().1, Some(0));
}
#[allow(clippy::bool_assert_comparison)]
#[tokio::test]
async fn valid_eos() {
assert_eq!(
ByteStream::from_static(b"hello").inner.body.is_end_stream(),
false
);
let mut f = NamedTempFile::new().unwrap();
f.write_all(b"hello").unwrap();
let body = ByteStream::from_path(f.path()).await.unwrap();
assert_eq!(body.inner.body.content_length(), Some(5));
assert!(!body.inner.body.is_end_stream());
assert_eq!(
ByteStream::from_static(b"").inner.body.is_end_stream(),
true
);
let mut f = NamedTempFile::new().unwrap();
f.write_all(b"").unwrap();
let body = ByteStream::from_path(f.path()).await.unwrap();
assert_eq!(body.inner.body.content_length(), Some(0));
assert!(body.inner.body.is_end_stream());
}
}

View File

@ -44,7 +44,10 @@ impl PathBody {
fn from_file(file: File, length: u64, buffer_size: usize) -> Self {
PathBody {
state: State::Loaded(ReaderStream::with_capacity(file.take(length), buffer_size)),
state: State::Loaded {
stream: ReaderStream::with_capacity(file.take(length), buffer_size),
bytes_left: length,
},
length,
buffer_size,
// The file used to create this `PathBody` should have already had an offset applied
@ -230,7 +233,10 @@ impl FsBuilder {
enum State {
Unloaded(PathBuf),
Loading(Pin<Box<dyn Future<Output = io::Result<File>> + Send + Sync + 'static>>),
Loaded(ReaderStream<io::Take<File>>),
Loaded {
stream: ReaderStream<io::Take<File>>,
bytes_left: u64,
},
}
impl http_body_0_4::Body for PathBody {
@ -238,7 +244,7 @@ impl http_body_0_4::Body for PathBody {
type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
fn poll_data(
mut self: std::pin::Pin<&mut Self>,
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Result<Self::Data, Self::Error>>> {
use std::task::Poll;
@ -260,18 +266,27 @@ impl http_body_0_4::Body for PathBody {
State::Loading(ref mut future) => {
match futures_core::ready!(Pin::new(future).poll(cx)) {
Ok(file) => {
self.state = State::Loaded(ReaderStream::with_capacity(
file.take(self.length),
self.buffer_size,
));
self.state = State::Loaded {
stream: ReaderStream::with_capacity(
file.take(self.length),
self.buffer_size,
),
bytes_left: self.length,
};
}
Err(e) => return Poll::Ready(Some(Err(e.into()))),
};
}
State::Loaded(ref mut stream) => {
State::Loaded {
ref mut stream,
ref mut bytes_left,
} => {
use futures_core::Stream;
return match futures_core::ready!(std::pin::Pin::new(stream).poll_next(cx)) {
Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes))),
return match futures_core::ready!(Pin::new(stream).poll_next(cx)) {
Some(Ok(bytes)) => {
*bytes_left -= bytes.len() as u64;
Poll::Ready(Some(Ok(bytes)))
}
None => Poll::Ready(None),
Some(Err(e)) => Poll::Ready(Some(Err(e.into()))),
};
@ -281,15 +296,17 @@ impl http_body_0_4::Body for PathBody {
}
fn poll_trailers(
self: std::pin::Pin<&mut Self>,
self: Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<Option<http::HeaderMap>, Self::Error>> {
std::task::Poll::Ready(Ok(None))
}
fn is_end_stream(&self) -> bool {
// fast path end-stream for empty streams
self.length == 0
match self.state {
State::Unloaded(_) | State::Loading(_) => self.length == 0,
State::Loaded { bytes_left, .. } => bytes_left == 0,
}
}
fn size_hint(&self) -> http_body_0_4::SizeHint {
@ -303,6 +320,7 @@ mod test {
use super::FsBuilder;
use crate::byte_stream::{ByteStream, Length};
use bytes::Buf;
use http_body_0_4::Body;
use std::io::Write;
use tempfile::NamedTempFile;
@ -370,6 +388,29 @@ mod test {
assert_eq!(body.content_length(), Some(1));
}
#[tokio::test]
async fn fsbuilder_is_end_stream() {
let sentence = "A very long sentence that's clearly longer than a single byte.";
let mut file = NamedTempFile::new().unwrap();
file.write_all(sentence.as_bytes()).unwrap();
// Ensure that the file was written to
file.flush().expect("flushing is OK");
let mut body = FsBuilder::new()
.path(&file)
.build()
.await
.unwrap()
.into_inner();
assert!(!body.is_end_stream());
assert_eq!(body.content_length(), Some(sentence.len() as u64));
let data = body.data().await.unwrap().unwrap();
assert_eq!(data.len(), sentence.len());
assert!(body.is_end_stream());
}
#[tokio::test]
async fn fsbuilder_respects_length() {
let mut file = NamedTempFile::new().unwrap();