Overhaul stalled stream protection and add upload support (#3485)

This PR overhauls the existing stalled stream protection with a new
algorithm, and also adds support for minimum throughput on upload
streams. The new algorithm adds support for differentiating between the
user or the server causing the stall, and not timing out if it's the
user causing the stall. This will fix timeout issues when a customer
makes remote service calls in between streaming pieces of information.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
This commit is contained in:
John DiSanti 2024-03-27 13:33:44 -07:00 committed by GitHub
parent 5c9379574d
commit 27834ae2cc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 2023 additions and 617 deletions

View File

@ -28,3 +28,45 @@ message = "Make `BehaviorVersion` be future-proof by disallowing it to be constr
references = ["aws-sdk-rust#1111", "smithy-rs#3513"]
meta = { "breaking" = true, "tada" = false, "bug" = true, "target" = "client" }
author = "Ten0"
[[smithy-rs]]
message = """
Stalled stream protection now supports request upload streams. It is currently off by default, but will be enabled by default in a future release. To enable it now, you can do the following:
```rust
let config = my_service::Config::builder()
.stalled_stream_protection(StalledStreamProtectionConfig::enabled().build())
// ...
.build();
```
"""
references = ["smithy-rs#3485"]
meta = { "breaking" = false, "tada" = true, "bug" = false }
authors = ["jdisanti"]
[[aws-sdk-rust]]
message = """
Stalled stream protection now supports request upload streams. It is currently off by default, but will be enabled by default in a future release. To enable it now, you can do the following:
```rust
let config = aws_config::defaults(BehaviorVersion::latest())
.stalled_stream_protection(StalledStreamProtectionConfig::enabled().build())
.load()
.await;
```
"""
references = ["smithy-rs#3485"]
meta = { "breaking" = false, "tada" = true, "bug" = false }
author = "jdisanti"
[[smithy-rs]]
message = "Stalled stream protection on downloads will now only trigger if the upstream source is too slow. Previously, stalled stream protection could be erroneously triggered if the user was slowly consuming the stream slower than the minimum speed limit."
references = ["smithy-rs#3485"]
meta = { "breaking" = false, "tada" = false, "bug" = true }
authors = ["jdisanti"]
[[aws-sdk-rust]]
message = "Stalled stream protection on downloads will now only trigger if the upstream source is too slow. Previously, stalled stream protection could be erroneously triggered if the user was slowly consuming the stream slower than the minimum speed limit."
references = ["smithy-rs#3485"]
meta = { "breaking" = false, "tada" = false, "bug" = true }
author = "jdisanti"

View File

@ -48,3 +48,6 @@ tracing-subscriber = { version = "0.3.15", features = ["env-filter", "json"] }
# If you're writing a test with this, take heed! `no-env-filter` means you'll be capturing
# logs from everything that speaks, so be specific with your asserts.
tracing-test = { version = "0.2.4", features = ["no-env-filter"] }
[dependencies]
pin-project-lite = "0.2.13"

View File

@ -0,0 +1,89 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
//! Body wrappers must pass through size_hint
use aws_config::SdkConfig;
use aws_sdk_s3::{
config::{Credentials, Region, SharedCredentialsProvider},
primitives::{ByteStream, SdkBody},
Client,
};
use aws_smithy_runtime::client::http::test_util::{capture_request, infallible_client_fn};
use http_body::Body;
#[tokio::test]
async fn download_body_size_hint_check() {
let test_body_content = b"hello";
let test_body = || SdkBody::from(&test_body_content[..]);
assert_eq!(
Some(test_body_content.len() as u64),
(test_body)().size_hint().exact(),
"pre-condition check"
);
let http_client = infallible_client_fn(move |_| {
http::Response::builder()
.status(200)
.body((test_body)())
.unwrap()
});
let sdk_config = SdkConfig::builder()
.credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests()))
.region(Region::new("us-east-1"))
.http_client(http_client)
.build();
let client = Client::new(&sdk_config);
let response = client
.get_object()
.bucket("foo")
.key("foo")
.send()
.await
.unwrap();
assert_eq!(
(
test_body_content.len() as u64,
Some(test_body_content.len() as u64),
),
response.body.size_hint(),
"the size hint should be passed through all the default body wrappers"
);
}
#[tokio::test]
async fn upload_body_size_hint_check() {
let test_body_content = b"hello";
let (http_client, rx) = capture_request(None);
let sdk_config = SdkConfig::builder()
.credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests()))
.region(Region::new("us-east-1"))
.http_client(http_client)
.build();
let client = Client::new(&sdk_config);
let body = ByteStream::from_static(test_body_content);
assert_eq!(
(
test_body_content.len() as u64,
Some(test_body_content.len() as u64),
),
body.size_hint(),
"pre-condition check"
);
let _response = client
.put_object()
.bucket("foo")
.key("foo")
.body(body)
.send()
.await;
let captured_request = rx.expect_request();
assert_eq!(
Some(test_body_content.len() as u64),
captured_request.body().size_hint().exact(),
"the size hint should be passed through all the default body wrappers"
);
}

View File

@ -4,27 +4,90 @@
*/
use aws_credential_types::Credentials;
use aws_sdk_s3::config::{Region, StalledStreamProtectionConfig};
use aws_sdk_s3::primitives::ByteStream;
use aws_sdk_s3::{
config::{Region, StalledStreamProtectionConfig},
error::BoxError,
};
use aws_sdk_s3::{error::DisplayErrorContext, primitives::ByteStream};
use aws_sdk_s3::{Client, Config};
use bytes::BytesMut;
use aws_smithy_runtime::{assert_str_contains, test_util::capture_test_logs::capture_test_logs};
use aws_smithy_types::body::SdkBody;
use bytes::{Bytes, BytesMut};
use http_body::Body;
use std::error::Error;
use std::future::Future;
use std::net::SocketAddr;
use std::time::Duration;
use std::{future::Future, task::Poll};
use std::{net::SocketAddr, pin::Pin, task::Context};
use tokio::{
net::{TcpListener, TcpStream},
time::sleep,
};
use tracing::debug;
// This test doesn't work because we can't count on `hyper` to poll the body,
// regardless of whether we schedule a wake. To make this functionality work,
// we'd have to integrate more closely with the orchestrator.
//
// I'll leave this test here because we do eventually want to support stalled
// stream protection for uploads.
#[ignore]
enum SlowBodyState {
Wait(Pin<Box<dyn std::future::Future<Output = ()> + Send + Sync + 'static>>),
Send,
Taken,
}
struct SlowBody {
state: SlowBodyState,
}
impl SlowBody {
fn new() -> Self {
Self {
state: SlowBodyState::Send,
}
}
}
impl Body for SlowBody {
type Data = Bytes;
type Error = BoxError;
fn poll_data(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
loop {
let mut state = SlowBodyState::Taken;
std::mem::swap(&mut state, &mut self.state);
match state {
SlowBodyState::Wait(mut fut) => match fut.as_mut().poll(cx) {
Poll::Ready(_) => self.state = SlowBodyState::Send,
Poll::Pending => {
self.state = SlowBodyState::Wait(fut);
return Poll::Pending;
}
},
SlowBodyState::Send => {
self.state = SlowBodyState::Wait(Box::pin(sleep(Duration::from_micros(100))));
return Poll::Ready(Some(Ok(Bytes::from_static(
b"data_data_data_data_data_data_data_data_data_data_data_data_\
data_data_data_data_data_data_data_data_data_data_data_data_\
data_data_data_data_data_data_data_data_data_data_data_data_\
data_data_data_data_data_data_data_data_data_data_data_data_",
))));
}
SlowBodyState::Taken => unreachable!(),
}
}
}
fn poll_trailers(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
Poll::Ready(Ok(None))
}
}
#[tokio::test]
async fn test_stalled_stream_protection_defaults_for_upload() {
// We spawn a faulty server that will close the connection after
// writing half of the response body.
let _logs = capture_test_logs();
// We spawn a faulty server that will stop all request processing after reading half of the request body.
let (server, server_addr) = start_faulty_upload_server().await;
let _ = tokio::spawn(server);
@ -32,7 +95,8 @@ async fn test_stalled_stream_protection_defaults_for_upload() {
.credentials_provider(Credentials::for_tests())
.region(Region::new("us-east-1"))
.endpoint_url(format!("http://{server_addr}"))
// .stalled_stream_protection(StalledStreamProtectionConfig::enabled().build())
// TODO(https://github.com/smithy-lang/smithy-rs/issues/3510): make stalled stream protection enabled by default with BMV and remove this line
.stalled_stream_protection(StalledStreamProtectionConfig::enabled().build())
.build();
let client = Client::from_conf(conf);
@ -40,22 +104,19 @@ async fn test_stalled_stream_protection_defaults_for_upload() {
.put_object()
.bucket("a-test-bucket")
.key("stalled-stream-test.txt")
.body(ByteStream::from_static(b"Hello"))
.body(ByteStream::new(SdkBody::from_body_0_4(SlowBody::new())))
.send()
.await
.expect_err("upload stream stalled out");
let err = err.source().expect("inner error exists");
assert_eq!(
err.to_string(),
let err_msg = DisplayErrorContext(&err).to_string();
assert_str_contains!(
err_msg,
"minimum throughput was specified at 1 B/s, but throughput of 0 B/s was observed"
);
}
async fn start_faulty_upload_server() -> (impl Future<Output = ()>, SocketAddr) {
use tokio::net::{TcpListener, TcpStream};
use tokio::time::sleep;
let listener = TcpListener::bind("0.0.0.0:0")
.await
.expect("socket is free");
@ -65,12 +126,7 @@ async fn start_faulty_upload_server() -> (impl Future<Output = ()>, SocketAddr)
let mut buf = BytesMut::new();
let mut time_to_stall = false;
loop {
if time_to_stall {
debug!("faulty server has read partial request, now getting stuck");
break;
}
while !time_to_stall {
match socket.try_read_buf(&mut buf) {
Ok(0) => {
unreachable!(
@ -79,12 +135,7 @@ async fn start_faulty_upload_server() -> (impl Future<Output = ()>, SocketAddr)
}
Ok(n) => {
debug!("read {n} bytes from the socket");
// Check to see if we've received some headers
if buf.len() >= 128 {
let s = String::from_utf8_lossy(&buf);
debug!("{s}");
time_to_stall = true;
}
}
@ -98,6 +149,7 @@ async fn start_faulty_upload_server() -> (impl Future<Output = ()>, SocketAddr)
}
}
debug!("faulty server has read partial request, now getting stuck");
loop {
tokio::task::yield_now().await
}
@ -229,14 +281,11 @@ async fn test_stalled_stream_protection_for_downloads_is_enabled_by_default() {
err.to_string(),
"minimum throughput was specified at 1 B/s, but throughput of 0 B/s was observed"
);
// 1s check interval + 5s grace period
assert_eq!(start.elapsed().as_secs(), 6);
// the 1s check interval is included in the 5s grace period
assert_eq!(start.elapsed().as_secs(), 5);
}
async fn start_faulty_download_server() -> (impl Future<Output = ()>, SocketAddr) {
use tokio::net::{TcpListener, TcpStream};
use tokio::time::sleep;
let listener = TcpListener::bind("0.0.0.0:0")
.await
.expect("socket is free");

View File

@ -120,15 +120,12 @@ class StalledStreamProtectionOperationCustomization(
is OperationSection.AdditionalInterceptors -> {
val stalledStreamProtectionModule = RuntimeType.smithyRuntime(rc).resolve("client::stalled_stream_protection")
section.registerInterceptor(rc, this) {
// Currently, only response bodies are protected/supported because
// we can't count on hyper to poll a request body on wake.
rustTemplate(
"""
#{StalledStreamProtectionInterceptor}::new(#{Kind}::ResponseBody)
#{StalledStreamProtectionInterceptor}::default()
""",
*preludeScope,
"StalledStreamProtectionInterceptor" to stalledStreamProtectionModule.resolve("StalledStreamProtectionInterceptor"),
"Kind" to stalledStreamProtectionModule.resolve("StalledStreamProtectionInterceptorKind"),
)
}
}

View File

@ -20,15 +20,17 @@ const DEFAULT_GRACE_PERIOD: Duration = Duration::from_secs(5);
/// When enabled, download streams that stall out will be cancelled.
#[derive(Clone, Debug)]
pub struct StalledStreamProtectionConfig {
is_enabled: bool,
upload_enabled: bool,
download_enabled: bool,
grace_period: Duration,
}
impl StalledStreamProtectionConfig {
/// Create a new config that enables stalled stream protection.
/// Create a new config that enables stalled stream protection for both uploads and downloads.
pub fn enabled() -> Builder {
Builder {
is_enabled: Some(true),
upload_enabled: Some(true),
download_enabled: Some(true),
grace_period: None,
}
}
@ -36,14 +38,25 @@ impl StalledStreamProtectionConfig {
/// Create a new config that disables stalled stream protection.
pub fn disabled() -> Self {
Self {
is_enabled: false,
upload_enabled: false,
download_enabled: false,
grace_period: DEFAULT_GRACE_PERIOD,
}
}
/// Return whether stalled stream protection is enabled.
/// Return whether stalled stream protection is enabled for either uploads or downloads.
pub fn is_enabled(&self) -> bool {
self.is_enabled
self.upload_enabled || self.download_enabled
}
/// True if stalled stream protection is enabled for upload streams.
pub fn upload_enabled(&self) -> bool {
self.upload_enabled
}
/// True if stalled stream protection is enabled for download streams.
pub fn download_enabled(&self) -> bool {
self.download_enabled
}
/// Return the grace period for stalled stream protection.
@ -57,7 +70,8 @@ impl StalledStreamProtectionConfig {
#[derive(Clone, Debug)]
pub struct Builder {
is_enabled: Option<bool>,
upload_enabled: Option<bool>,
download_enabled: Option<bool>,
grace_period: Option<Duration>,
}
@ -74,22 +88,48 @@ impl Builder {
self
}
/// Set whether stalled stream protection is enabled.
pub fn is_enabled(mut self, is_enabled: bool) -> Self {
self.is_enabled = Some(is_enabled);
/// Set whether stalled stream protection is enabled for both uploads and downloads.
pub fn is_enabled(mut self, enabled: bool) -> Self {
self.set_is_enabled(Some(enabled));
self
}
/// Set whether stalled stream protection is enabled.
pub fn set_is_enabled(&mut self, is_enabled: Option<bool>) -> &mut Self {
self.is_enabled = is_enabled;
/// Set whether stalled stream protection is enabled for both uploads and downloads.
pub fn set_is_enabled(&mut self, enabled: Option<bool>) -> &mut Self {
self.set_upload_enabled(enabled);
self.set_download_enabled(enabled);
self
}
/// Set whether stalled stream protection is enabled for upload streams.
pub fn upload_enabled(mut self, enabled: bool) -> Self {
self.set_upload_enabled(Some(enabled));
self
}
/// Set whether stalled stream protection is enabled for upload streams.
pub fn set_upload_enabled(&mut self, enabled: Option<bool>) -> &mut Self {
self.upload_enabled = enabled;
self
}
/// Set whether stalled stream protection is enabled for download streams.
pub fn download_enabled(mut self, enabled: bool) -> Self {
self.set_download_enabled(Some(enabled));
self
}
/// Set whether stalled stream protection is enabled for download streams.
pub fn set_download_enabled(&mut self, enabled: Option<bool>) -> &mut Self {
self.download_enabled = enabled;
self
}
/// Build the config.
pub fn build(self) -> StalledStreamProtectionConfig {
StalledStreamProtectionConfig {
is_enabled: self.is_enabled.unwrap_or_default(),
upload_enabled: self.upload_enabled.unwrap_or_default(),
download_enabled: self.download_enabled.unwrap_or_default(),
grace_period: self.grace_period.unwrap_or(DEFAULT_GRACE_PERIOD),
}
}
@ -98,7 +138,8 @@ impl Builder {
impl From<StalledStreamProtectionConfig> for Builder {
fn from(config: StalledStreamProtectionConfig) -> Self {
Builder {
is_enabled: Some(config.is_enabled),
upload_enabled: Some(config.upload_enabled),
download_enabled: Some(config.download_enabled),
grace_period: Some(config.grace_period),
}
}

View File

@ -43,7 +43,7 @@ serde_json = { version = "1", features = ["preserve_order"], optional = true }
indexmap = { version = "2", optional = true, features = ["serde"] }
tokio = { version = "1.25", features = [] }
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.16", optional = true, features = ["fmt", "json"] }
tracing-subscriber = { version = "0.3.16", optional = true, features = ["env-filter", "fmt", "json"] }
[dev-dependencies]
approx = "0.5.1"

View File

@ -171,7 +171,16 @@ pub fn default_identity_cache_plugin() -> Option<SharedRuntimePlugin> {
///
/// By default, when throughput falls below 1/Bs for more than 5 seconds, the
/// stream is cancelled.
#[deprecated(
since = "1.2.0",
note = "This function wasn't intended to be public, and didn't take the behavior major version as an argument, so it couldn't be evolved over time."
)]
pub fn default_stalled_stream_protection_config_plugin() -> Option<SharedRuntimePlugin> {
default_stalled_stream_protection_config_plugin_v2(BehaviorVersion::v2023_11_09())
}
fn default_stalled_stream_protection_config_plugin_v2(
_behavior_version: BehaviorVersion,
) -> Option<SharedRuntimePlugin> {
Some(
default_plugin(
"default_stalled_stream_protection_config_plugin",
@ -184,6 +193,8 @@ pub fn default_stalled_stream_protection_config_plugin() -> Option<SharedRuntime
.with_config(layer("default_stalled_stream_protection_config", |layer| {
layer.store_put(
StalledStreamProtectionConfig::enabled()
// TODO(https://github.com/smithy-lang/smithy-rs/issues/3510): enable behind new behavior version
.upload_enabled(false)
.grace_period(Duration::from_secs(5))
.build(),
);
@ -257,6 +268,10 @@ impl DefaultPluginParams {
pub fn default_plugins(
params: DefaultPluginParams,
) -> impl IntoIterator<Item = SharedRuntimePlugin> {
let behavior_version = params
.behavior_version
.unwrap_or_else(BehaviorVersion::latest);
[
default_http_client_plugin(),
default_identity_cache_plugin(),
@ -268,7 +283,7 @@ pub fn default_plugins(
default_sleep_impl_plugin(),
default_time_source_plugin(),
default_timeout_config_plugin(),
default_stalled_stream_protection_config_plugin(),
default_stalled_stream_protection_config_plugin_v2(behavior_version),
enforce_content_length_runtime_plugin(),
]
.into_iter()

View File

@ -15,25 +15,46 @@ pub mod options;
pub use throughput::Throughput;
mod throughput;
use crate::client::http::body::minimum_throughput::throughput::ThroughputReport;
use aws_smithy_async::rt::sleep::Sleep;
use aws_smithy_async::rt::sleep::{AsyncSleep, SharedAsyncSleep};
use aws_smithy_async::time::{SharedTimeSource, TimeSource};
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::shared::IntoShared;
use aws_smithy_runtime_api::{
box_error::BoxError,
client::{
http::HttpConnectorFuture, result::ConnectorError, runtime_components::RuntimeComponents,
stalled_stream_protection::StalledStreamProtectionConfig,
},
};
use aws_smithy_runtime_api::{client::orchestrator::HttpResponse, shared::IntoShared};
use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
use options::MinimumThroughputBodyOptions;
use std::fmt;
use std::time::SystemTime;
use std::{
fmt,
sync::{Arc, Mutex},
task::Poll,
};
use std::{future::Future, pin::Pin};
use std::{
task::Context,
time::{Duration, SystemTime},
};
use throughput::ThroughputLogs;
/// Use [`MinimumThroughputDownloadBody`] instead.
#[deprecated(note = "Renamed to MinimumThroughputDownloadBody since it doesn't work for uploads")]
pub type MinimumThroughputBody<B> = MinimumThroughputDownloadBody<B>;
pin_project_lite::pin_project! {
/// A body-wrapping type that ensures data is being streamed faster than some lower limit.
///
/// If data is being streamed too slowly, this body type will emit an error next time it's polled.
pub struct MinimumThroughputBody<B> {
pub struct MinimumThroughputDownloadBody<B> {
async_sleep: SharedAsyncSleep,
time_source: SharedTimeSource,
options: MinimumThroughputBodyOptions,
throughput_logs: ThroughputLogs,
resolution: Duration,
#[pin]
sleep_fut: Option<Sleep>,
#[pin]
@ -43,10 +64,7 @@ pin_project_lite::pin_project! {
}
}
const SIZE_OF_ONE_LOG: usize = std::mem::size_of::<(SystemTime, u64)>(); // 24 bytes per log
const NUMBER_OF_LOGS_IN_ONE_KB: f64 = 1024.0 / SIZE_OF_ONE_LOG as f64;
impl<B> MinimumThroughputBody<B> {
impl<B> MinimumThroughputDownloadBody<B> {
/// Create a new minimum throughput body.
pub fn new(
time_source: impl TimeSource + 'static,
@ -54,14 +72,15 @@ impl<B> MinimumThroughputBody<B> {
body: B,
options: MinimumThroughputBodyOptions,
) -> Self {
let time_source: SharedTimeSource = time_source.into_shared();
let now = time_source.now();
let throughput_logs = ThroughputLogs::new(options.check_window(), now);
let resolution = throughput_logs.resolution();
Self {
throughput_logs: ThroughputLogs::new(
// Never keep more than 10KB of logs in memory. This currently
// equates to 426 logs.
(NUMBER_OF_LOGS_IN_ONE_KB * 10.0) as usize,
),
throughput_logs,
resolution,
async_sleep: async_sleep.into_shared(),
time_source: time_source.into_shared(),
time_source,
inner: body,
sleep_fut: None,
grace_period_fut: None,
@ -93,4 +112,286 @@ impl fmt::Display for Error {
impl std::error::Error for Error {}
// Tests are implemented per HTTP body type.
/// Used to store the upload throughput in the interceptor context.
#[derive(Clone, Debug)]
pub(crate) struct UploadThroughput {
logs: Arc<Mutex<ThroughputLogs>>,
}
impl UploadThroughput {
pub(crate) fn new(time_window: Duration, now: SystemTime) -> Self {
Self {
logs: Arc::new(Mutex::new(ThroughputLogs::new(time_window, now))),
}
}
pub(crate) fn resolution(&self) -> Duration {
self.logs.lock().unwrap().resolution()
}
pub(crate) fn push_pending(&self, now: SystemTime) {
self.logs.lock().unwrap().push_pending(now);
}
pub(crate) fn push_bytes_transferred(&self, now: SystemTime, bytes: u64) {
self.logs.lock().unwrap().push_bytes_transferred(now, bytes);
}
pub(crate) fn report(&self, now: SystemTime) -> ThroughputReport {
self.logs.lock().unwrap().report(now)
}
}
impl Storable for UploadThroughput {
type Storer = StoreReplace<Self>;
}
pin_project_lite::pin_project! {
pub(crate) struct ThroughputReadingBody<B> {
time_source: SharedTimeSource,
throughput: UploadThroughput,
#[pin]
inner: B,
}
}
impl<B> ThroughputReadingBody<B> {
pub(crate) fn new(
time_source: SharedTimeSource,
throughput: UploadThroughput,
body: B,
) -> Self {
Self {
time_source,
throughput,
inner: body,
}
}
}
const ZERO_THROUGHPUT: Throughput = Throughput::new_bytes_per_second(0);
// Helper trait for interpretting the throughput report.
trait UploadReport {
fn minimum_throughput_violated(self, minimum_throughput: Throughput) -> (bool, Throughput);
}
impl UploadReport for ThroughputReport {
fn minimum_throughput_violated(self, minimum_throughput: Throughput) -> (bool, Throughput) {
let throughput = match self {
// If the report is incomplete, then we don't have enough data yet to
// decide if minimum throughput was violated.
ThroughputReport::Incomplete => {
tracing::trace!(
"not enough data to decide if minimum throughput has been violated"
);
return (false, ZERO_THROUGHPUT);
}
// If most of the datapoints are Poll::Pending, then the user has stalled.
// In this case, we don't want to say minimum throughput was violated.
ThroughputReport::Pending => {
tracing::debug!(
"the user has stalled; this will not become a minimum throughput violation"
);
return (false, ZERO_THROUGHPUT);
}
// If there has been no polling, then the server has stalled. Alternatively,
// if we're transferring data, but it's too slow, then we also want to say
// that the minimum throughput has been violated.
ThroughputReport::NoPolling => ZERO_THROUGHPUT,
ThroughputReport::Transferred(tp) => tp,
};
if throughput < minimum_throughput {
tracing::debug!(
"current throughput: {throughput} is below minimum: {minimum_throughput}"
);
(true, throughput)
} else {
(false, throughput)
}
}
}
pin_project_lite::pin_project! {
/// Future that pairs with [`UploadThroughput`] to add a minimum throughput
/// requirement to a request upload stream.
struct UploadThroughputCheckFuture {
#[pin]
response: HttpConnectorFuture,
#[pin]
check_interval: Option<Sleep>,
#[pin]
grace_period: Option<Sleep>,
time_source: SharedTimeSource,
sleep_impl: SharedAsyncSleep,
upload_throughput: UploadThroughput,
resolution: Duration,
options: MinimumThroughputBodyOptions,
failing_throughput: Option<Throughput>,
}
}
impl UploadThroughputCheckFuture {
fn new(
response: HttpConnectorFuture,
time_source: SharedTimeSource,
sleep_impl: SharedAsyncSleep,
upload_throughput: UploadThroughput,
options: MinimumThroughputBodyOptions,
) -> Self {
let resolution = upload_throughput.resolution();
Self {
response,
check_interval: Some(sleep_impl.sleep(resolution)),
grace_period: None,
time_source,
sleep_impl,
upload_throughput,
resolution,
options,
failing_throughput: None,
}
}
}
impl Future for UploadThroughputCheckFuture {
type Output = Result<HttpResponse, ConnectorError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
if let Poll::Ready(output) = this.response.poll(cx) {
return Poll::Ready(output);
} else {
let mut below_minimum_throughput = false;
let check_interval_expired = this
.check_interval
.as_mut()
.as_pin_mut()
.expect("always set")
.poll(cx)
.is_ready();
if check_interval_expired {
// Set up the next check interval
*this.check_interval = Some(this.sleep_impl.sleep(*this.resolution));
// Wake so that the check interval future gets polled
// next time this poll method is called. If it never gets polled,
// then this task won't be woken to check again.
cx.waker().wake_by_ref();
}
let should_check = check_interval_expired || this.grace_period.is_some();
if should_check {
let now = this.time_source.now();
let report = this.upload_throughput.report(now);
let (violated, current_throughput) =
report.minimum_throughput_violated(this.options.minimum_throughput());
below_minimum_throughput = violated;
if below_minimum_throughput && !this.failing_throughput.is_some() {
*this.failing_throughput = Some(current_throughput);
} else if !below_minimum_throughput {
*this.failing_throughput = None;
}
}
// If we kicked off a grace period and are now satisfied, clear out the grace period
if !below_minimum_throughput && this.grace_period.is_some() {
tracing::debug!("upload minimum throughput recovered during grace period");
*this.grace_period = None;
}
if below_minimum_throughput {
// Start a grace period if below minimum throughput
if this.grace_period.is_none() {
tracing::debug!(
grace_period=?this.options.grace_period(),
"upload minimum throughput below configured minimum; starting grace period"
);
*this.grace_period = Some(this.sleep_impl.sleep(this.options.grace_period()));
}
// Check the grace period if one is already set and we're not satisfied
if let Some(grace_period) = this.grace_period.as_pin_mut() {
if grace_period.poll(cx).is_ready() {
tracing::debug!("grace period ended; timing out request");
return Poll::Ready(Err(ConnectorError::timeout(
Error::ThroughputBelowMinimum {
expected: this.options.minimum_throughput(),
actual: this
.failing_throughput
.expect("always set if there's a grace period"),
}
.into(),
)));
}
}
}
}
Poll::Pending
}
}
pin_project_lite::pin_project! {
#[project = EnumProj]
pub(crate) enum MaybeUploadThroughputCheckFuture {
Direct { #[pin] future: HttpConnectorFuture },
Checked { #[pin] future: UploadThroughputCheckFuture },
}
}
impl MaybeUploadThroughputCheckFuture {
pub(crate) fn new(
cfg: &mut ConfigBag,
components: &RuntimeComponents,
connector_future: HttpConnectorFuture,
) -> Self {
if let Some(sspcfg) = cfg.load::<StalledStreamProtectionConfig>().cloned() {
if sspcfg.is_enabled() {
let options = MinimumThroughputBodyOptions::from(sspcfg);
return Self::new_inner(
connector_future,
components.time_source(),
components.sleep_impl(),
cfg.interceptor_state().load::<UploadThroughput>().cloned(),
Some(options),
);
}
}
tracing::debug!("no minimum upload throughput checks");
Self::new_inner(connector_future, None, None, None, None)
}
fn new_inner(
response: HttpConnectorFuture,
time_source: Option<SharedTimeSource>,
sleep_impl: Option<SharedAsyncSleep>,
upload_throughput: Option<UploadThroughput>,
options: Option<MinimumThroughputBodyOptions>,
) -> Self {
match (time_source, sleep_impl, upload_throughput, options) {
(Some(time_source), Some(sleep_impl), Some(upload_throughput), Some(options)) => {
tracing::debug!(options=?options, "applying minimum upload throughput check future");
Self::Checked {
future: UploadThroughputCheckFuture::new(
response,
time_source,
sleep_impl,
upload_throughput,
options,
),
}
}
_ => Self::Direct { future: response },
}
}
}
impl Future for MaybeUploadThroughputCheckFuture {
type Output = Result<HttpResponse, ConnectorError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project() {
EnumProj::Direct { future } => future.poll(cx),
EnumProj::Checked { future } => future.poll(cx),
}
}
}

View File

@ -3,14 +3,58 @@
* SPDX-License-Identifier: Apache-2.0
*/
use super::{BoxError, Error, MinimumThroughputBody};
use super::{BoxError, Error, MinimumThroughputDownloadBody};
use crate::client::http::body::minimum_throughput::{
throughput::ThroughputReport, Throughput, ThroughputReadingBody,
};
use aws_smithy_async::rt::sleep::AsyncSleep;
use http_body_0_4::Body;
use std::future::Future;
use std::pin::{pin, Pin};
use std::task::{Context, Poll};
impl<B> Body for MinimumThroughputBody<B>
const ZERO_THROUGHPUT: Throughput = Throughput::new_bytes_per_second(0);
// Helper trait for interpretting the throughput report.
trait DownloadReport {
fn minimum_throughput_violated(self, minimum_throughput: Throughput) -> (bool, Throughput);
}
impl DownloadReport for ThroughputReport {
fn minimum_throughput_violated(self, minimum_throughput: Throughput) -> (bool, Throughput) {
let throughput = match self {
// If the report is incomplete, then we don't have enough data yet to
// decide if minimum throughput was violated.
ThroughputReport::Incomplete => {
tracing::trace!(
"not enough data to decide if minimum throughput has been violated"
);
return (false, ZERO_THROUGHPUT);
}
// If no polling is taking place, then the user has stalled.
// In this case, we don't want to say minimum throughput was violated.
ThroughputReport::NoPolling => {
tracing::debug!(
"the user has stalled; this will not become a minimum throughput violation"
);
return (false, ZERO_THROUGHPUT);
}
// If we're stuck in Poll::Pending, then the server has stalled. Alternatively,
// if we're transferring data, but it's too slow, then we also want to say
// that the minimum throughput has been violated.
ThroughputReport::Pending => ZERO_THROUGHPUT,
ThroughputReport::Transferred(tp) => tp,
};
let violated = throughput < minimum_throughput;
if violated {
tracing::debug!(
"current throughput: {throughput} is below minimum: {minimum_throughput}"
);
}
(violated, throughput)
}
}
impl<B> Body for MinimumThroughputDownloadBody<B>
where
B: Body<Data = bytes::Bytes, Error = BoxError>,
{
@ -30,12 +74,13 @@ where
let poll_res = match this.inner.poll_data(cx) {
Poll::Ready(Some(Ok(bytes))) => {
tracing::trace!("received data: {}", bytes.len());
this.throughput_logs.push((now, bytes.len() as u64));
this.throughput_logs
.push_bytes_transferred(now, bytes.len() as u64);
Poll::Ready(Some(Ok(bytes)))
}
Poll::Pending => {
tracing::trace!("received poll pending");
this.throughput_logs.push((now, 0));
this.throughput_logs.push_pending(now);
Poll::Pending
}
// If we've read all the data or an error occurred, then return that result.
@ -46,44 +91,27 @@ where
let mut sleep_fut = this
.sleep_fut
.take()
.unwrap_or_else(|| this.async_sleep.sleep(this.options.check_interval()));
.unwrap_or_else(|| this.async_sleep.sleep(*this.resolution));
if let Poll::Ready(()) = pin!(&mut sleep_fut).poll(cx) {
tracing::trace!("sleep future triggered—triggering a wakeup");
// Whenever the sleep future expires, we replace it.
sleep_fut = this.async_sleep.sleep(this.options.check_interval());
sleep_fut = this.async_sleep.sleep(*this.resolution);
// We also schedule a wake up for current task to ensure that
// it gets polled at least one more time.
cx.waker().wake_by_ref();
};
this.sleep_fut.replace(sleep_fut);
let calculated_tpt = match this
.throughput_logs
.calculate_throughput(now, this.options.check_window())
{
Some(tpt) => tpt,
None => {
tracing::trace!("calculated throughput is None!");
return poll_res;
}
};
tracing::trace!(
"calculated throughput {:?} (window: {:?})",
calculated_tpt,
this.options.check_window()
);
// Calculate the current throughput and emit an error if it's too low and
// the grace period has elapsed.
let is_below_minimum_throughput = calculated_tpt <= this.options.minimum_throughput();
if is_below_minimum_throughput {
// Check the grace period future to see if it needs creating.
tracing::trace!(
in_grace_period = this.grace_period_fut.is_some(),
observed_throughput = ?calculated_tpt,
minimum_throughput = ?this.options.minimum_throughput(),
"below minimum throughput"
);
let report = this.throughput_logs.report(now);
let (violated, current_throughput) =
report.minimum_throughput_violated(this.options.minimum_throughput());
if violated {
if this.grace_period_fut.is_none() {
tracing::debug!("entering minimum throughput grace period");
}
let mut grace_period_fut = this
.grace_period_fut
.take()
@ -92,13 +120,16 @@ where
// The grace period has ended!
return Poll::Ready(Some(Err(Box::new(Error::ThroughputBelowMinimum {
expected: self.options.minimum_throughput(),
actual: calculated_tpt,
actual: current_throughput,
}))));
};
this.grace_period_fut.replace(grace_period_fut);
} else {
// Ensure we don't have an active grace period future if we're not
// currently below the minimum throughput.
if this.grace_period_fut.is_some() {
tracing::debug!("throughput recovered; exiting grace period");
}
let _ = this.grace_period_fut.take();
}
@ -112,290 +143,63 @@ where
let this = self.as_mut().project();
this.inner.poll_trailers(cx)
}
}
// These tests use `hyper::body::Body::wrap_stream`
#[cfg(all(test, feature = "connector-hyper-0-14-x", feature = "test-util"))]
mod test {
use super::{super::Throughput, Error, MinimumThroughputBody};
use crate::client::http::body::minimum_throughput::options::MinimumThroughputBodyOptions;
use crate::test_util::capture_test_logs::capture_test_logs;
use aws_smithy_async::rt::sleep::AsyncSleep;
use aws_smithy_async::test_util::{instant_time_and_sleep, InstantSleep, ManualTimeSource};
use aws_smithy_types::body::SdkBody;
use aws_smithy_types::byte_stream::{AggregatedBytes, ByteStream};
use aws_smithy_types::error::display::DisplayErrorContext;
use bytes::{BufMut, Bytes, BytesMut};
use http::HeaderMap;
use http_body_0_4::Body;
use once_cell::sync::Lazy;
use pretty_assertions::assert_eq;
use std::convert::Infallible;
use std::error::Error as StdError;
use std::future::{poll_fn, Future};
use std::pin::{pin, Pin};
use std::task::{Context, Poll};
use std::time::{Duration, UNIX_EPOCH};
struct NeverBody;
impl Body for NeverBody {
type Data = Bytes;
type Error = Box<(dyn StdError + Send + Sync + 'static)>;
fn poll_data(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
Poll::Pending
}
fn poll_trailers(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
unreachable!("body can't be read, so this won't be called")
}
fn size_hint(&self) -> http_body_0_4::SizeHint {
self.inner.size_hint()
}
#[tokio::test()]
async fn test_self_waking() {
let (time_source, async_sleep) = instant_time_and_sleep(UNIX_EPOCH);
let mut body = MinimumThroughputBody::new(
time_source.clone(),
async_sleep.clone(),
NeverBody,
Default::default(),
);
time_source.advance(Duration::from_secs(1));
let actual_err = body.data().await.expect("next chunk exists").unwrap_err();
let expected_err = Error::ThroughputBelowMinimum {
expected: (1, Duration::from_secs(1)).into(),
actual: (0, Duration::from_secs(1)).into(),
};
assert_eq!(expected_err.to_string(), actual_err.to_string());
}
fn create_test_stream(
async_sleep: impl AsyncSleep + Clone,
) -> impl futures_util::Stream<Item = Result<Bytes, Infallible>> {
futures_util::stream::unfold(1, move |state| {
let async_sleep = async_sleep.clone();
async move {
if state > 255 {
None
} else {
async_sleep.sleep(Duration::from_secs(1)).await;
Some((
Result::<_, Infallible>::Ok(Bytes::from_static(b"00000000")),
state + 1,
))
}
}
})
}
static EXPECTED_BYTES: Lazy<Vec<u8>> =
Lazy::new(|| (1..=255).flat_map(|_| b"00000000").copied().collect());
fn eight_byte_per_second_stream_with_minimum_throughput_timeout(
minimum_throughput: Throughput,
) -> (
impl Future<Output = Result<AggregatedBytes, aws_smithy_types::byte_stream::error::Error>>,
ManualTimeSource,
InstantSleep,
) {
let (time_source, async_sleep) = instant_time_and_sleep(UNIX_EPOCH);
let time_clone = time_source.clone();
// Will send ~8 bytes per second.
let stream = create_test_stream(async_sleep.clone());
let body = ByteStream::new(SdkBody::from_body_0_4(hyper_0_14::body::Body::wrap_stream(
stream,
)));
let body = body.map(move |body| {
let time_source = time_clone.clone();
// We don't want to log these sleeps because it would duplicate
// the `sleep` calls being logged by the MTB
let async_sleep = InstantSleep::unlogged();
SdkBody::from_body_0_4(MinimumThroughputBody::new(
time_source,
async_sleep,
body,
MinimumThroughputBodyOptions::builder()
.minimum_throughput(minimum_throughput)
.build(),
))
});
(body.collect(), time_source, async_sleep)
}
async fn expect_error(minimum_throughput: Throughput) {
let (res, ..) =
eight_byte_per_second_stream_with_minimum_throughput_timeout(minimum_throughput);
let expected_err = Error::ThroughputBelowMinimum {
expected: minimum_throughput,
actual: Throughput::new(8, Duration::from_secs(1)),
};
match res.await {
Ok(_) => {
panic!(
"response succeeded instead of returning the expected error '{expected_err}'"
)
}
Err(actual_err) => {
assert_eq!(
expected_err.to_string(),
// We need to source this so that we don't get the streaming error it's wrapped in.
actual_err.source().unwrap().to_string()
);
}
}
}
#[tokio::test]
async fn test_throughput_timeout_less_than() {
let minimum_throughput = Throughput::new_bytes_per_second(9);
expect_error(minimum_throughput).await;
}
async fn expect_success(minimum_throughput: Throughput) {
let (res, time_source, async_sleep) =
eight_byte_per_second_stream_with_minimum_throughput_timeout(minimum_throughput);
match res.await {
Ok(res) => {
assert_eq!(255.0, time_source.seconds_since_unix_epoch());
assert_eq!(Duration::from_secs(255), async_sleep.total_duration());
assert_eq!(*EXPECTED_BYTES, res.to_vec());
}
Err(err) => panic!("{}", DisplayErrorContext(err.source().unwrap())),
}
}
#[tokio::test]
async fn test_throughput_timeout_equal_to() {
let (_guard, _) = capture_test_logs();
// a tiny bit less. To capture 0-throughput properly, we need to allow 0 to be 0
let minimum_throughput = Throughput::new(31, Duration::from_secs(4));
expect_success(minimum_throughput).await;
}
#[tokio::test]
async fn test_throughput_timeout_greater_than() {
let minimum_throughput = Throughput::new(20, Duration::from_secs(3));
expect_success(minimum_throughput).await;
}
// A multiplier for the sine wave amplitude; Chosen arbitrarily.
const BYTE_COUNT_UPPER_LIMIT: u64 = 1000;
/// emits 1000B/S for 5 seconds then suddenly stops
fn sudden_stop(
async_sleep: impl AsyncSleep + Clone,
) -> impl futures_util::Stream<Item = Result<Bytes, Infallible>> {
let sleep_dur = Duration::from_millis(50);
fastrand::seed(0);
futures_util::stream::unfold(1, move |i| {
let async_sleep = async_sleep.clone();
async move {
let number_seconds = (i * sleep_dur).as_secs_f64();
async_sleep.sleep(sleep_dur).await;
if number_seconds > 5.0 {
Some((Result::<Bytes, Infallible>::Ok(Bytes::new()), i + 1))
} else {
let mut bytes = BytesMut::new();
let bytes_per_segment =
(BYTE_COUNT_UPPER_LIMIT as f64) * sleep_dur.as_secs_f64();
for _ in 0..bytes_per_segment as usize {
bytes.put_u8(0)
}
Some((Result::<Bytes, Infallible>::Ok(bytes.into()), i + 1))
}
}
})
}
#[tokio::test]
async fn test_stalled_stream_detection() {
test_suddenly_stopping_stream(0, Duration::from_secs(6)).await
}
#[tokio::test]
async fn test_slow_stream_detection() {
test_suddenly_stopping_stream(BYTE_COUNT_UPPER_LIMIT / 2, Duration::from_secs_f64(5.50))
.await
}
#[tokio::test]
async fn test_check_interval() {
let (_guard, _) = capture_test_logs();
let (ts, sleep) = instant_time_and_sleep(UNIX_EPOCH);
let mut body = MinimumThroughputBody::new(
ts,
sleep.clone(),
NeverBody,
MinimumThroughputBodyOptions::builder()
.check_interval(Duration::from_millis(1234))
.grace_period(Duration::from_millis(456))
.build(),
);
let mut body = pin!(body);
let _ = poll_fn(|cx| body.as_mut().poll_data(cx)).await;
assert_eq!(
sleep.logs(),
vec![
// sleep, by second sleep we know we have no data, then the grace period
Duration::from_millis(1234),
Duration::from_millis(1234),
Duration::from_millis(456)
]
);
}
async fn test_suddenly_stopping_stream(throughput_limit: u64, time_until_timeout: Duration) {
let (_guard, _) = capture_test_logs();
let options = MinimumThroughputBodyOptions::builder()
// Minimum throughput per second will be approx. half of the BYTE_COUNT_UPPER_LIMIT.
.minimum_throughput(Throughput::new_bytes_per_second(throughput_limit))
.build();
let (time_source, async_sleep) = instant_time_and_sleep(UNIX_EPOCH);
let time_clone = time_source.clone();
let stream = sudden_stop(async_sleep.clone());
let body = ByteStream::new(SdkBody::from_body_0_4(hyper_0_14::body::Body::wrap_stream(
stream,
)));
let res = body
.map(move |body| {
let time_source = time_clone.clone();
// We don't want to log these sleeps because it would duplicate
// the `sleep` calls being logged by the MTB
let async_sleep = InstantSleep::unlogged();
SdkBody::from_body_0_4(MinimumThroughputBody::new(
time_source,
async_sleep,
body,
options.clone(),
))
})
.collect();
match res.await {
Ok(_res) => {
panic!("stream should have timed out");
}
Err(err) => {
dbg!(err);
assert_eq!(
async_sleep.total_duration(),
time_until_timeout,
"With throughput limit {:?} expected timeout after {:?} (stream starts sending 0's at 5 seconds.",
throughput_limit, time_until_timeout
);
}
}
fn is_end_stream(&self) -> bool {
self.inner.is_end_stream()
}
}
impl<B> Body for ThroughputReadingBody<B>
where
B: Body<Data = bytes::Bytes, Error = BoxError>,
{
type Data = bytes::Bytes;
type Error = BoxError;
fn poll_data(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
// this code is called quite frequently in production—one every millisecond or so when downloading
// a stream. However, SystemTime::now is on the order of nanoseconds
let now = self.time_source.now();
// Attempt to read the data from the inner body, then update the
// throughput logs.
let this = self.as_mut().project();
match this.inner.poll_data(cx) {
Poll::Ready(Some(Ok(bytes))) => {
tracing::trace!("received data: {}", bytes.len());
this.throughput
.push_bytes_transferred(now, bytes.len() as u64);
Poll::Ready(Some(Ok(bytes)))
}
Poll::Pending => {
tracing::trace!("received poll pending");
this.throughput.push_pending(now);
Poll::Pending
}
// If we've read all the data or an error occurred, then return that result.
res => res,
}
}
fn poll_trailers(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
let this = self.as_mut().project();
this.inner.poll_trailers(cx)
}
fn size_hint(&self) -> http_body_0_4::SizeHint {
self.inner.size_hint()
}
fn is_end_stream(&self) -> bool {
self.inner.is_end_stream()
}
}

View File

@ -12,6 +12,7 @@ use std::time::Duration;
pub struct MinimumThroughputBodyOptions {
/// The minimum throughput that is acceptable.
minimum_throughput: Throughput,
/// The 'grace period' after which the minimum throughput will be enforced.
///
/// If this is set to 0, the minimum throughput will be enforced immediately.
@ -24,9 +25,6 @@ pub struct MinimumThroughputBodyOptions {
/// stream-startup.
grace_period: Duration,
/// The interval at which the throughput is checked.
check_interval: Duration,
/// The period of time to consider when computing the throughput
///
/// This SHOULD be longer than the check interval, or stuck-streams may evade detection.
@ -44,7 +42,6 @@ impl MinimumThroughputBodyOptions {
MinimumThroughputBodyOptionsBuilder::new()
.minimum_throughput(self.minimum_throughput)
.grace_period(self.grace_period)
.check_interval(self.check_interval)
}
/// The throughput check grace period.
@ -65,12 +62,10 @@ impl MinimumThroughputBodyOptions {
self.check_window
}
/// The rate at which the throughput is checked.
///
/// The actual rate throughput is checked may be higher than this value,
/// but it will never be lower.
/// Not used. Always returns `Duration::from_millis(500)`.
#[deprecated(note = "No longer used. Always returns Duration::from_millis(500)")]
pub fn check_interval(&self) -> Duration {
self.check_interval
Duration::from_millis(500)
}
}
@ -79,7 +74,6 @@ impl Default for MinimumThroughputBodyOptions {
Self {
minimum_throughput: DEFAULT_MINIMUM_THROUGHPUT,
grace_period: DEFAULT_GRACE_PERIOD,
check_interval: DEFAULT_CHECK_INTERVAL,
check_window: DEFAULT_CHECK_WINDOW,
}
}
@ -89,11 +83,10 @@ impl Default for MinimumThroughputBodyOptions {
#[derive(Debug, Default, Clone)]
pub struct MinimumThroughputBodyOptionsBuilder {
minimum_throughput: Option<Throughput>,
check_interval: Option<Duration>,
check_window: Option<Duration>,
grace_period: Option<Duration>,
}
const DEFAULT_CHECK_INTERVAL: Duration = Duration::from_millis(500);
const DEFAULT_GRACE_PERIOD: Duration = Duration::from_secs(0);
const DEFAULT_MINIMUM_THROUGHPUT: Throughput = Throughput {
bytes_read: 1,
@ -136,19 +129,30 @@ impl MinimumThroughputBodyOptionsBuilder {
self
}
/// Set the rate at which throughput is checked.
///
/// Defaults to 1 second.
pub fn check_interval(mut self, check_interval: Duration) -> Self {
self.set_check_interval(Some(check_interval));
/// No longer used. The check interval is now based on the check window (not currently configurable).
#[deprecated(
note = "No longer used. The check interval is now based on the check window (not currently configurable). Open an issue if you need to configure the check window."
)]
pub fn check_interval(self, _check_interval: Duration) -> Self {
self
}
/// Set the rate at which throughput is checked.
///
/// Defaults to 1 second.
pub fn set_check_interval(&mut self, check_interval: Option<Duration>) -> &mut Self {
self.check_interval = check_interval;
/// No longer used. The check interval is now based on the check window (not currently configurable).
#[deprecated(
note = "No longer used. The check interval is now based on the check window (not currently configurable). Open an issue if you need to configure the check window."
)]
pub fn set_check_interval(&mut self, _check_interval: Option<Duration>) -> &mut Self {
self
}
#[allow(unused)]
pub(crate) fn check_window(mut self, check_window: Duration) -> Self {
self.set_check_window(Some(check_window));
self
}
#[allow(unused)]
pub(crate) fn set_check_window(&mut self, check_window: Option<Duration>) -> &mut Self {
self.check_window = check_window;
self
}
@ -161,8 +165,7 @@ impl MinimumThroughputBodyOptionsBuilder {
minimum_throughput: self
.minimum_throughput
.unwrap_or(DEFAULT_MINIMUM_THROUGHPUT),
check_interval: self.check_interval.unwrap_or(DEFAULT_CHECK_INTERVAL),
check_window: DEFAULT_CHECK_WINDOW,
check_window: self.check_window.unwrap_or(DEFAULT_CHECK_WINDOW),
}
}
}
@ -172,7 +175,6 @@ impl From<StalledStreamProtectionConfig> for MinimumThroughputBodyOptions {
MinimumThroughputBodyOptions {
grace_period: value.grace_period(),
minimum_throughput: DEFAULT_MINIMUM_THROUGHPUT,
check_interval: DEFAULT_CHECK_INTERVAL,
check_window: DEFAULT_CHECK_WINDOW,
}
}

View File

@ -3,12 +3,12 @@
* SPDX-License-Identifier: Apache-2.0
*/
use std::collections::VecDeque;
use std::fmt;
use std::time::{Duration, SystemTime};
/// Throughput representation for use when configuring [`super::MinimumThroughputBody`]
#[derive(Debug, Clone, Copy)]
#[cfg_attr(test, derive(Eq))]
pub struct Throughput {
pub(super) bytes_read: u64,
pub(super) per_time_elapsed: Duration,
@ -29,7 +29,7 @@ impl Throughput {
}
/// Create a new throughput in bytes per second.
pub fn new_bytes_per_second(bytes: u64) -> Self {
pub const fn new_bytes_per_second(bytes: u64) -> Self {
Self {
bytes_read: bytes,
per_time_elapsed: Duration::from_secs(1),
@ -37,7 +37,7 @@ impl Throughput {
}
/// Create a new throughput in kilobytes per second.
pub fn new_kilobytes_per_second(kilobytes: u64) -> Self {
pub const fn new_kilobytes_per_second(kilobytes: u64) -> Self {
Self {
bytes_read: kilobytes * 1000,
per_time_elapsed: Duration::from_secs(1),
@ -45,7 +45,7 @@ impl Throughput {
}
/// Create a new throughput in megabytes per second.
pub fn new_megabytes_per_second(megabytes: u64) -> Self {
pub const fn new_megabytes_per_second(megabytes: u64) -> Self {
Self {
bytes_read: megabytes * 1000 * 1000,
per_time_elapsed: Duration::from_secs(1),
@ -97,90 +97,288 @@ impl From<(u64, Duration)> for Throughput {
}
}
#[derive(Clone)]
/// Overall label for a given bin.
#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq)]
enum BinLabel {
// IMPORTANT: The order of these enums matters since it represents their priority:
// Pending > TransferredBytes > NoPolling > Empty
//
/// There is no data in this bin.
Empty,
/// No polling took place during this bin.
NoPolling,
/// This many bytes were transferred during this bin.
TransferredBytes,
/// The user/remote was not providing/consuming data fast enough during this bin.
///
/// The number is the number of bytes transferred, if this replaced TransferredBytes.
Pending,
}
/// Represents a bin (or a cell) in a linear grid that represents a small chunk of time.
#[derive(Copy, Clone, Debug)]
struct Bin {
label: BinLabel,
bytes: u64,
}
impl Bin {
const fn new(label: BinLabel, bytes: u64) -> Self {
Self { label, bytes }
}
const fn empty() -> Self {
Self::new(BinLabel::Empty, 0)
}
fn is_empty(&self) -> bool {
matches!(self.label, BinLabel::Empty)
}
fn merge(&mut self, other: Bin) -> &mut Self {
// Assign values based on this priority order (highest priority higher up):
// 1. Pending
// 2. TransferredBytes
// 3. NoPolling
// 4. Empty
self.label = if other.label > self.label {
other.label
} else {
self.label
};
self.bytes += other.bytes;
self
}
/// Number of bytes transferred during this bin
fn bytes(&self) -> u64 {
self.bytes
}
}
#[derive(Copy, Clone, Debug, Default)]
struct BinCounts {
/// Number of bins with no data.
empty: usize,
/// Number of "no polling" bins.
no_polling: usize,
/// Number of "bytes transferred" bins.
transferred: usize,
/// Number of "pending" bins.
pending: usize,
}
/// Underlying stack-allocated linear grid buffer for tracking
/// throughput events for [`ThroughputLogs`].
#[derive(Copy, Clone, Debug)]
struct LogBuffer<const N: usize> {
entries: [Bin; N],
// The length only needs to exist so that the `fill_gaps` function
// can differentiate between `Empty` due to there not having been enough
// time to establish a full buffer worth of data vs. `Empty` due to a
// polling gap. Once the length reaches N, it will never change again.
length: usize,
}
impl<const N: usize> LogBuffer<N> {
fn new() -> Self {
Self {
entries: [Bin::empty(); N],
length: 0,
}
}
/// Mutably returns the tail of the buffer.
///
/// ## Panics
///
/// The buffer MUST have at least one bin in it before this is called.
fn tail_mut(&mut self) -> &mut Bin {
debug_assert!(self.length > 0);
&mut self.entries[self.length - 1]
}
/// Pushes a bin into the buffer. If the buffer is already full,
/// then this will rotate the entire buffer to the left.
fn push(&mut self, bin: Bin) {
if self.filled() {
self.entries.rotate_left(1);
self.entries[N - 1] = bin;
} else {
self.entries[self.length] = bin;
self.length += 1;
}
}
/// Returns the total number of bytes transferred within the time window.
fn bytes_transferred(&self) -> u64 {
self.entries.iter().take(self.length).map(Bin::bytes).sum()
}
#[inline]
fn filled(&self) -> bool {
self.length == N
}
/// Fills in missing NoData entries.
///
/// We want NoData entries to represent when a future hasn't been polled.
/// Since the future is in charge of logging in the first place, the only
/// way we can know about these is by examining gaps in time.
fn fill_gaps(&mut self) {
for entry in self.entries.iter_mut().take(self.length) {
if entry.is_empty() {
*entry = Bin::new(BinLabel::NoPolling, 0);
}
}
}
/// Returns the counts of each bin type in the buffer.
fn counts(&self) -> BinCounts {
let mut counts = BinCounts::default();
for entry in &self.entries {
match entry.label {
BinLabel::Empty => counts.empty += 1,
BinLabel::NoPolling => counts.no_polling += 1,
BinLabel::TransferredBytes => counts.transferred += 1,
BinLabel::Pending => counts.pending += 1,
}
}
counts
}
}
/// Report/summary of all the events in a time window.
#[cfg_attr(test, derive(Debug, Eq, PartialEq))]
pub(crate) enum ThroughputReport {
/// Not enough data to draw any conclusions. This happens early in a request/response.
Incomplete,
/// The stream hasn't been polled for most of this time window.
NoPolling,
/// The stream has been waiting for most of the time window.
Pending,
/// The stream transferred this amount of throughput during the time window.
Transferred(Throughput),
}
const BIN_COUNT: usize = 10;
/// Log of throughput in a request or response stream.
///
/// Used to determine if a configured minimum throughput is being met or not
/// so that a request or response stream can be timed out in the event of a
/// stall.
///
/// Request/response streams push data transfer or pending events to this log
/// based on what's going on in their poll functions. The log tracks three kinds
/// of events despite only receiving two: the third is "no polling". The poll
/// functions cannot know when they're not being polled, so the log examines gaps
/// in the event history to know when no polling took place.
///
/// The event logging is simplified down to a linear grid consisting of 10 "bins",
/// with each bin representing 1/10th the total time window. When an event is pushed,
/// it is either merged into the current tail bin, or all the bins are rotated
/// left to create a new empty tail bin, and then it is merged into that one.
#[derive(Clone, Debug)]
pub(super) struct ThroughputLogs {
max_length: usize,
inner: VecDeque<(SystemTime, u64)>,
bytes_processed: u64,
resolution: Duration,
current_tail: SystemTime,
buffer: LogBuffer<BIN_COUNT>,
}
impl ThroughputLogs {
pub(super) fn new(max_length: usize) -> Self {
/// Creates a new log starting at `now` with the given `time_window`.
///
/// Note: the `time_window` gets divided by 10 to create smaller sub-windows
/// to track throughput. The time window should be configured to be large enough
/// so that these sub-windows aren't too small for network-based events.
/// A time window of 10ms probably won't work, but 500ms might. The default
/// is one second.
pub(super) fn new(time_window: Duration, now: SystemTime) -> Self {
assert!(!time_window.is_zero());
let resolution = time_window.div_f64(BIN_COUNT as f64);
Self {
inner: VecDeque::with_capacity(max_length),
max_length,
bytes_processed: 0,
resolution,
current_tail: now,
buffer: LogBuffer::new(),
}
}
pub(super) fn push(&mut self, throughput: (SystemTime, u64)) {
// When the number of logs exceeds the max length, toss the oldest log.
if self.inner.len() == self.max_length {
self.bytes_processed -= self.inner.pop_front().map(|(_, sz)| sz).unwrap_or_default();
}
debug_assert!(self.inner.capacity() > self.inner.len());
self.bytes_processed += throughput.1;
self.inner.push_back(throughput);
/// Returns the resolution at which events are logged at.
///
/// The resolution is the number of bins in the time window.
pub(super) fn resolution(&self) -> Duration {
self.resolution
}
fn buffer_full(&self) -> bool {
self.inner.len() == self.max_length
/// Pushes a "pending" event.
///
/// Pending indicates the streaming future is waiting for something.
/// In an upload, it is waiting for data from the user, and in a download,
/// it is waiting for data from the server.
pub(super) fn push_pending(&mut self, time: SystemTime) {
self.push(time, Bin::new(BinLabel::Pending, 0));
}
pub(super) fn calculate_throughput(
&self,
now: SystemTime,
time_window: Duration,
) -> Option<Throughput> {
// There are a lot of pathological cases that are 0 throughput. These cases largely shouldn't
// happen, because the check interval MUST be less than the check window
let total_length = self
.inner
.iter()
.last()?
.0
.duration_since(self.inner.front()?.0)
.ok()?;
// during a "healthy" request we'll only have a few milliseconds of logs (shorter than the check window)
if total_length < time_window {
// if we haven't hit our requested time window & the buffer still isn't full, then
// return `None` — this is the "startup grace period"
return if !self.buffer_full() {
None
} else {
// Otherwise, if the entire buffer fits in the timewindow, we can the shortcut to
// avoid recomputing all the data
Some(Throughput {
bytes_read: self.bytes_processed,
per_time_elapsed: total_length,
})
};
/// Pushes a data transferred event.
///
/// Indicates that this number of bytes were transferred at this time.
pub(super) fn push_bytes_transferred(&mut self, time: SystemTime, bytes: u64) {
self.push(time, Bin::new(BinLabel::TransferredBytes, bytes));
}
fn push(&mut self, now: SystemTime, value: Bin) {
self.catch_up(now);
self.buffer.tail_mut().merge(value);
self.buffer.fill_gaps();
}
/// Pushes empty bins until `current_tail` is caught up to `now`.
fn catch_up(&mut self, now: SystemTime) {
while now >= self.current_tail {
self.current_tail += self.resolution;
self.buffer.push(Bin::empty());
}
let minimum_ts = now - time_window;
let first_item = self.inner.iter().find(|(ts, _)| *ts >= minimum_ts)?.0;
assert!(self.current_tail >= now);
}
let time_elapsed = now.duration_since(first_item).unwrap_or_default();
/// Generates an overall report of the time window.
pub(super) fn report(&mut self, now: SystemTime) -> ThroughputReport {
self.catch_up(now);
self.buffer.fill_gaps();
let total_bytes_logged = self
.inner
.iter()
.rev()
.take_while(|(ts, _)| *ts > minimum_ts)
.map(|t| t.1)
.sum::<u64>();
let BinCounts {
empty,
no_polling,
transferred,
pending,
} = self.buffer.counts();
Some(Throughput {
bytes_read: total_bytes_logged,
per_time_elapsed: time_elapsed,
})
// If there are any empty cells at all, then we haven't been tracking
// long enough to make any judgements about the stream's progress.
if empty > 0 {
return ThroughputReport::Incomplete;
}
let bytes = self.buffer.bytes_transferred();
let time = self.resolution * (BIN_COUNT - empty) as u32;
let throughput = Throughput::new(bytes, time);
let half = BIN_COUNT / 2;
match (transferred > 0, no_polling >= half, pending >= half) {
(true, _, _) => ThroughputReport::Transferred(throughput),
(_, true, _) => ThroughputReport::NoPolling,
(_, _, true) => ThroughputReport::Pending,
_ => ThroughputReport::Incomplete,
}
}
}
#[cfg(test)]
mod test {
use super::{Throughput, ThroughputLogs};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use super::*;
use std::time::Duration;
#[test]
fn test_throughput_eq() {
@ -192,92 +390,146 @@ mod test {
assert_eq!(t2, t3);
}
fn build_throughput_log(
length: u32,
tick_duration: Duration,
rate: u64,
) -> (ThroughputLogs, SystemTime) {
let mut throughput_logs = ThroughputLogs::new(length as usize);
for i in 1..=length {
throughput_logs.push((UNIX_EPOCH + (tick_duration * i), rate));
}
assert_eq!(length as usize, throughput_logs.inner.len());
(throughput_logs, UNIX_EPOCH + (tick_duration * length))
#[test]
fn incomplete_no_entries() {
let start = SystemTime::UNIX_EPOCH;
let mut logs = ThroughputLogs::new(Duration::from_secs(1), start);
let report = logs.report(start);
assert_eq!(ThroughputReport::Incomplete, report);
}
const EPSILON: f64 = 0.001;
macro_rules! assert_delta {
($x:expr, $y:expr, $d:expr) => {
if !(($x as f64) - $y < $d || $y - ($x as f64) < $d) {
panic!();
#[test]
fn incomplete_with_entries() {
let start = SystemTime::UNIX_EPOCH;
let mut logs = ThroughputLogs::new(Duration::from_secs(1), start);
logs.push_pending(start);
let report = logs.report(start + Duration::from_millis(300));
assert_eq!(ThroughputReport::Incomplete, report);
}
#[test]
fn incomplete_with_transferred() {
let start = SystemTime::UNIX_EPOCH;
let mut logs = ThroughputLogs::new(Duration::from_secs(1), start);
logs.push_pending(start);
logs.push_bytes_transferred(start + Duration::from_millis(100), 10);
let report = logs.report(start + Duration::from_millis(300));
assert_eq!(ThroughputReport::Incomplete, report);
}
#[test]
fn push_pending_at_the_beginning_of_each_tick() {
let start = SystemTime::UNIX_EPOCH;
let mut logs = ThroughputLogs::new(Duration::from_secs(1), start);
let mut now = start;
for i in 1..=BIN_COUNT {
logs.push_pending(now);
now += logs.resolution();
assert_eq!(i, logs.buffer.counts().pending);
}
let report = dbg!(&mut logs).report(now);
assert_eq!(ThroughputReport::Pending, report);
}
#[test]
fn push_pending_at_the_end_of_each_tick() {
let start = SystemTime::UNIX_EPOCH;
let mut logs = ThroughputLogs::new(Duration::from_secs(1), start);
let mut now = start;
for i in 1..BIN_COUNT {
now += logs.resolution();
logs.push_pending(now);
assert_eq!(i, dbg!(&logs).buffer.counts().pending);
assert_eq!(0, logs.buffer.counts().transferred);
assert_eq!(1, logs.buffer.counts().no_polling);
}
// This should replace the initial "no polling" bin
now += logs.resolution();
logs.push_pending(now);
assert_eq!(0, logs.buffer.counts().no_polling);
let report = dbg!(&mut logs).report(now);
assert_eq!(ThroughputReport::Pending, report);
}
#[test]
fn push_transferred_at_the_beginning_of_each_tick() {
let start = SystemTime::UNIX_EPOCH;
let mut logs = ThroughputLogs::new(Duration::from_secs(1), start);
let mut now = start;
for i in 1..=BIN_COUNT {
logs.push_bytes_transferred(now, 10);
if i != BIN_COUNT {
now += logs.resolution();
}
};
}
#[test]
fn test_throughput_log_calculate_throughput_1() {
let (throughput_logs, now) = build_throughput_log(1000, Duration::from_secs(1), 1);
for dur in [10, 100, 100] {
let throughput = throughput_logs
.calculate_throughput(now, Duration::from_secs(dur))
.unwrap();
assert_eq!(1.0, throughput.bytes_per_second());
assert_eq!(i, logs.buffer.counts().transferred);
assert_eq!(0, logs.buffer.counts().pending);
assert_eq!(0, logs.buffer.counts().no_polling);
}
let throughput = throughput_logs
.calculate_throughput(now, Duration::from_secs_f64(101.5))
.unwrap();
assert_delta!(1, throughput.bytes_per_second(), EPSILON);
let report = dbg!(&mut logs).report(now);
assert_eq!(
ThroughputReport::Transferred(Throughput::new(100, Duration::from_secs(1))),
report
);
}
#[test]
fn test_throughput_log_calculate_throughput_2() {
let (throughput_logs, now) = build_throughput_log(1000, Duration::from_secs(5), 5);
fn no_polling() {
let start = SystemTime::UNIX_EPOCH;
let mut logs = ThroughputLogs::new(Duration::from_secs(1), start);
let report = logs.report(start + Duration::from_secs(2));
assert_eq!(ThroughputReport::NoPolling, report);
}
let throughput = throughput_logs
.calculate_throughput(now, Duration::from_secs(1000))
.unwrap();
assert_eq!(1.0, throughput.bytes_per_second());
// Transferred bytes MUST take priority over pending
#[test]
fn mixed_bag_mostly_pending() {
let start = SystemTime::UNIX_EPOCH;
let mut logs = ThroughputLogs::new(Duration::from_secs(1), start);
logs.push_bytes_transferred(start + Duration::from_millis(50), 10);
logs.push_pending(start + Duration::from_millis(150));
logs.push_pending(start + Duration::from_millis(250));
logs.push_bytes_transferred(start + Duration::from_millis(350), 10);
logs.push_pending(start + Duration::from_millis(450));
// skip 550
logs.push_pending(start + Duration::from_millis(650));
logs.push_pending(start + Duration::from_millis(750));
logs.push_pending(start + Duration::from_millis(850));
let report = logs.report(start + Duration::from_millis(999));
assert_eq!(
ThroughputReport::Transferred(Throughput::new_bytes_per_second(20)),
report
);
}
#[test]
fn test_throughput_log_calculate_throughput_3() {
let (throughput_logs, now) = build_throughput_log(1000, Duration::from_millis(200), 1024);
fn mixed_bag_mostly_pending_no_transferred() {
let start = SystemTime::UNIX_EPOCH;
let mut logs = ThroughputLogs::new(Duration::from_secs(1), start);
let throughput = throughput_logs
.calculate_throughput(now, Duration::from_secs(5))
.unwrap();
let expected_throughput = 1024.0 * 5.0;
assert_eq!(expected_throughput, throughput.bytes_per_second());
}
logs.push_pending(start + Duration::from_millis(50));
logs.push_pending(start + Duration::from_millis(150));
logs.push_pending(start + Duration::from_millis(250));
// skip 350
logs.push_pending(start + Duration::from_millis(450));
// skip 550
logs.push_pending(start + Duration::from_millis(650));
logs.push_pending(start + Duration::from_millis(750));
logs.push_pending(start + Duration::from_millis(850));
#[test]
fn test_throughput_log_calculate_throughput_4() {
let (throughput_logs, now) = build_throughput_log(1000, Duration::from_millis(100), 12);
let throughput = throughput_logs
.calculate_throughput(now, Duration::from_secs(1))
.unwrap();
let expected_throughput = 12.0 * 10.0;
assert_eq!(expected_throughput, throughput.bytes_per_second());
}
#[test]
fn test_throughput_followed_by_0() {
let tick = Duration::from_millis(100);
let (mut throughput_logs, now) = build_throughput_log(1000, tick, 12);
let throughput = throughput_logs
.calculate_throughput(now, Duration::from_secs(1))
.unwrap();
let expected_throughput = 12.0 * 10.0;
assert_eq!(expected_throughput, throughput.bytes_per_second());
throughput_logs.push((now + tick, 0));
let throughput = throughput_logs
.calculate_throughput(now + tick, Duration::from_secs(1))
.unwrap();
assert_eq!(108.0, throughput.bytes_per_second());
let report = logs.report(start + Duration::from_millis(999));
assert_eq!(ThroughputReport::Pending, report);
}
}

View File

@ -5,9 +5,12 @@
use self::auth::orchestrate_auth;
use crate::client::interceptors::Interceptors;
use crate::client::orchestrator::endpoints::orchestrate_endpoint;
use crate::client::orchestrator::http::{log_response_body, read_body};
use crate::client::timeout::{MaybeTimeout, MaybeTimeoutConfig, TimeoutKind};
use crate::client::{
http::body::minimum_throughput::MaybeUploadThroughputCheckFuture,
orchestrator::endpoints::orchestrate_endpoint,
};
use aws_smithy_async::rt::sleep::AsyncSleep;
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::http::{HttpClient, HttpConnector, HttpConnectorSettings};
@ -385,7 +388,12 @@ async fn try_attempt(
builder.build()
};
let connector = http_client.http_connector(&settings, runtime_components);
connector.call(request).await.map_err(OrchestratorError::connector)
let response_future = MaybeUploadThroughputCheckFuture::new(
cfg,
runtime_components,
connector.call(request),
);
response_future.await.map_err(OrchestratorError::connector)
});
trace!(response = ?response, "received response from service");
ctx.set_response(response);

View File

@ -12,7 +12,6 @@ use crate::client::orchestrator::endpoints::StaticUriEndpointResolver;
use crate::client::retries::strategy::{NeverRetryStrategy, StandardRetryStrategy};
use aws_smithy_async::rt::sleep::AsyncSleep;
use aws_smithy_async::time::TimeSource;
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::auth::static_resolver::StaticAuthSchemeOptionResolver;
use aws_smithy_runtime_api::client::auth::{
AuthSchemeOptionResolverParams, SharedAuthScheme, SharedAuthSchemeOptionResolver,
@ -35,6 +34,9 @@ use aws_smithy_runtime_api::client::ser_de::{
DeserializeResponse, SerializeRequest, SharedRequestSerializer, SharedResponseDeserializer,
};
use aws_smithy_runtime_api::shared::IntoShared;
use aws_smithy_runtime_api::{
box_error::BoxError, client::stalled_stream_protection::StalledStreamProtectionConfig,
};
use aws_smithy_types::config_bag::{ConfigBag, Layer};
use aws_smithy_types::retry::RetryConfig;
use aws_smithy_types::timeout::TimeoutConfig;
@ -293,6 +295,15 @@ impl<I, O, E> OperationBuilder<I, O, E> {
self
}
/// Configures stalled stream protection with the given config.
pub fn stalled_stream_protection(
mut self,
stalled_stream_protection: StalledStreamProtectionConfig,
) -> Self {
self.config.store_put(stalled_stream_protection);
self
}
/// Configures the serializer for the builder.
pub fn serializer<I2>(
mut self,
@ -339,6 +350,28 @@ impl<I, O, E> OperationBuilder<I, O, E> {
}
}
/// Configures the a deserializer implementation for the builder.
pub fn deserializer_impl<O2, E2>(
mut self,
deserializer: impl DeserializeResponse + Send + Sync + 'static,
) -> OperationBuilder<I, O2, E2>
where
O2: fmt::Debug + Send + Sync + 'static,
E2: std::error::Error + fmt::Debug + Send + Sync + 'static,
{
let deserializer: SharedResponseDeserializer = deserializer.into_shared();
self.config.store_put(deserializer);
OperationBuilder {
service_name: self.service_name,
operation_name: self.operation_name,
config: self.config,
runtime_components: self.runtime_components,
runtime_plugins: self.runtime_plugins,
_phantom: Default::default(),
}
}
/// Creates an `Operation` from the builder.
pub fn build(self) -> Operation<I, O, E> {
let service_name = self.service_name.expect("service_name required");

View File

@ -3,7 +3,10 @@
* SPDX-License-Identifier: Apache-2.0
*/
use crate::client::http::body::minimum_throughput::MinimumThroughputBody;
use crate::client::http::body::minimum_throughput::{
options::MinimumThroughputBodyOptions, MinimumThroughputDownloadBody, ThroughputReadingBody,
UploadThroughput,
};
use aws_smithy_async::rt::sleep::SharedAsyncSleep;
use aws_smithy_async::time::SharedTimeSource;
use aws_smithy_runtime_api::box_error::BoxError;
@ -18,14 +21,16 @@ use aws_smithy_types::config_bag::ConfigBag;
use std::mem;
/// Adds stalled stream protection when sending requests and/or receiving responses.
#[derive(Debug)]
pub struct StalledStreamProtectionInterceptor {
enable_for_request_body: bool,
enable_for_response_body: bool,
}
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct StalledStreamProtectionInterceptor;
/// Stalled stream protection can be enable for request bodies, response bodies,
/// or both.
#[deprecated(
since = "1.2.0",
note = "This kind enum is no longer used. Configuration is stored in StalledStreamProtectionConfig in the config bag."
)]
pub enum StalledStreamProtectionInterceptorKind {
/// Enable stalled stream protection for request bodies.
RequestBody,
@ -37,18 +42,13 @@ pub enum StalledStreamProtectionInterceptorKind {
impl StalledStreamProtectionInterceptor {
/// Create a new stalled stream protection interceptor.
pub fn new(kind: StalledStreamProtectionInterceptorKind) -> Self {
use StalledStreamProtectionInterceptorKind::*;
let (enable_for_request_body, enable_for_response_body) = match kind {
RequestBody => (true, false),
ResponseBody => (false, true),
RequestAndResponseBody => (true, true),
};
Self {
enable_for_request_body,
enable_for_response_body,
}
#[deprecated(
since = "1.2.0",
note = "The kind enum is no longer used. Configuration is stored in StalledStreamProtectionConfig in the config bag. Construct the interceptor using Default."
)]
#[allow(deprecated)]
pub fn new(_kind: StalledStreamProtectionInterceptorKind) -> Self {
Default::default()
}
}
@ -63,19 +63,26 @@ impl Intercept for StalledStreamProtectionInterceptor {
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
if self.enable_for_request_body {
if let Some(cfg) = cfg.load::<StalledStreamProtectionConfig>() {
if cfg.is_enabled() {
let (async_sleep, time_source) =
get_runtime_component_deps(runtime_components)?;
tracing::trace!("adding stalled stream protection to request body");
add_stalled_stream_protection_to_body(
context.request_mut().body_mut(),
cfg,
async_sleep,
if let Some(sspcfg) = cfg.load::<StalledStreamProtectionConfig>().cloned() {
if sspcfg.upload_enabled() {
let (_async_sleep, time_source) = get_runtime_component_deps(runtime_components)?;
let now = time_source.now();
let options: MinimumThroughputBodyOptions = sspcfg.into();
let throughput = UploadThroughput::new(options.check_window(), now);
cfg.interceptor_state().store_put(throughput.clone());
tracing::trace!("adding stalled stream protection to request body");
let it = mem::replace(context.request_mut().body_mut(), SdkBody::taken());
let it = it.map_preserve_contents(move |body| {
let time_source = time_source.clone();
SdkBody::from_body_0_4(ThroughputReadingBody::new(
time_source,
);
}
throughput.clone(),
body,
))
});
let _ = mem::replace(context.request_mut().body_mut(), it);
}
}
@ -88,19 +95,25 @@ impl Intercept for StalledStreamProtectionInterceptor {
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
if self.enable_for_response_body {
if let Some(cfg) = cfg.load::<StalledStreamProtectionConfig>() {
if cfg.is_enabled() {
let (async_sleep, time_source) =
get_runtime_component_deps(runtime_components)?;
tracing::trace!("adding stalled stream protection to response body");
add_stalled_stream_protection_to_body(
context.response_mut().body_mut(),
cfg,
async_sleep,
if let Some(sspcfg) = cfg.load::<StalledStreamProtectionConfig>() {
if sspcfg.download_enabled() {
let (async_sleep, time_source) = get_runtime_component_deps(runtime_components)?;
tracing::trace!("adding stalled stream protection to response body");
let sspcfg = sspcfg.clone();
let it = mem::replace(context.response_mut().body_mut(), SdkBody::taken());
let it = it.map_preserve_contents(move |body| {
let sspcfg = sspcfg.clone();
let async_sleep = async_sleep.clone();
let time_source = time_source.clone();
let mtb = MinimumThroughputDownloadBody::new(
time_source,
async_sleep,
body,
sspcfg.into(),
);
}
SdkBody::from_body_0_4(mtb)
});
let _ = mem::replace(context.response_mut().body_mut(), it);
}
}
Ok(())
@ -118,21 +131,3 @@ fn get_runtime_component_deps(
.ok_or("A time source is required when stalled stream protection is enabled")?;
Ok((async_sleep, time_source))
}
fn add_stalled_stream_protection_to_body(
body: &mut SdkBody,
cfg: &StalledStreamProtectionConfig,
async_sleep: SharedAsyncSleep,
time_source: SharedTimeSource,
) {
let cfg = cfg.clone();
let it = mem::replace(body, SdkBody::taken());
let it = it.map_preserve_contents(move |body| {
let cfg = cfg.clone();
let async_sleep = async_sleep.clone();
let time_source = time_source.clone();
let mtb = MinimumThroughputBody::new(time_source, async_sleep, body, cfg.into());
SdkBody::from_body_0_4(mtb)
});
let _ = mem::replace(body, it);
}

View File

@ -14,6 +14,29 @@ use tracing_subscriber::fmt::TestWriter;
#[derive(Debug)]
pub struct LogCaptureGuard(#[allow(dead_code)] DefaultGuard);
/// Enables output of test logs to stdout at trace level by default.
///
/// The env filter can be changed with the `RUST_LOG` environment variable.
#[must_use]
pub fn show_test_logs() -> LogCaptureGuard {
let (mut writer, _rx) = Tee::stdout();
writer.loud();
let env_var = env::var("RUST_LOG").ok();
let env_filter = env_var.as_deref().unwrap_or("trace");
eprintln!(
"Enabled verbose test logging with env filter {env_filter:?}. \
You can change the env filter with the RUST_LOG environment variable."
);
let subscriber = tracing_subscriber::fmt()
.with_env_filter(env_filter)
.with_writer(Mutex::new(writer))
.finish();
let guard = tracing::subscriber::set_default(subscriber);
LogCaptureGuard(guard)
}
/// Capture logs from this test.
///
/// The logs will be captured until the `DefaultGuard` is dropped.

View File

@ -0,0 +1,113 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
#![cfg(all(feature = "client", feature = "test-util"))]
pub use aws_smithy_async::{
test_util::tick_advance_sleep::{
tick_advance_time_and_sleep, TickAdvanceSleep, TickAdvanceTime,
},
time::TimeSource,
};
pub use aws_smithy_runtime::{
assert_str_contains,
client::{
orchestrator::operation::Operation,
stalled_stream_protection::StalledStreamProtectionInterceptor,
},
test_util::capture_test_logs::show_test_logs,
};
pub use aws_smithy_runtime_api::{
box_error::BoxError,
client::{
http::{
HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings,
SharedHttpConnector,
},
interceptors::context::{Error, Output},
orchestrator::{HttpRequest, HttpResponse, OrchestratorError},
result::SdkError,
runtime_components::RuntimeComponents,
ser_de::DeserializeResponse,
stalled_stream_protection::StalledStreamProtectionConfig,
},
http::{Response, StatusCode},
shared::IntoShared,
};
pub use aws_smithy_types::{
body::SdkBody, error::display::DisplayErrorContext, timeout::TimeoutConfig,
};
pub use bytes::Bytes;
pub use http_body_0_4::Body;
pub use pin_utils::pin_mut;
pub use std::{
collections::VecDeque,
convert::Infallible,
future::poll_fn,
mem,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
time::Duration,
};
pub use tracing::{info, Instrument as _};
/// No really, it's 42 bytes long... super neat
pub const NEAT_DATA: Bytes = Bytes::from_static(b"some really neat data");
/// Ticks time forward by the given duration, and logs the current time for debugging.
#[macro_export]
macro_rules! tick {
($ticker:ident, $duration:expr) => {
$ticker.tick($duration).await;
let now = $ticker
.now()
.duration_since(std::time::SystemTime::UNIX_EPOCH)
.unwrap();
tracing::info!("ticked {:?}, now at {:?}", $duration, now);
};
}
#[derive(Debug)]
pub struct FakeServer(pub SharedHttpConnector);
impl HttpClient for FakeServer {
fn http_connector(
&self,
_settings: &HttpConnectorSettings,
_components: &RuntimeComponents,
) -> SharedHttpConnector {
self.0.clone()
}
}
struct ChannelBody {
receiver: tokio::sync::mpsc::Receiver<Bytes>,
}
impl http_body_0_4::Body for ChannelBody {
type Data = Bytes;
type Error = Infallible;
fn poll_data(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
match self.receiver.poll_recv(cx) {
Poll::Ready(value) => Poll::Ready(value.map(|v| Ok(v))),
Poll::Pending => Poll::Pending,
}
}
fn poll_trailers(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
unreachable!()
}
}
pub fn channel_body() -> (SdkBody, tokio::sync::mpsc::Sender<Bytes>) {
let (sender, receiver) = tokio::sync::mpsc::channel(1000);
(SdkBody::from_body_0_4(ChannelBody { receiver }), sender)
}

View File

@ -0,0 +1,297 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
#![cfg(all(feature = "client", feature = "test-util"))]
use std::time::Duration;
#[macro_use]
mod stalled_stream_common;
use stalled_stream_common::*;
/// Scenario: Successfully download at a rate above the minimum throughput.
/// Expected: MUST NOT timeout.
#[tokio::test]
async fn download_success() {
let _logs = show_test_logs();
let (time, sleep) = tick_advance_time_and_sleep();
let (server, response_sender) = channel_server();
let op = operation(server, time.clone(), sleep);
let server = tokio::spawn(async move {
for _ in 1..100 {
response_sender.send(NEAT_DATA).await.unwrap();
tick!(time, Duration::from_secs(1));
}
drop(response_sender);
tick!(time, Duration::from_secs(1));
});
let response_body = op.invoke(()).await.expect("initial success");
let result = eagerly_consume(response_body).await;
server.await.unwrap();
result.ok().expect("response MUST NOT timeout");
}
/// Scenario: Download takes a some time to start, but then goes normally.
/// Expected: MUT NOT timeout.
#[tokio::test]
async fn download_slow_start() {
let _logs = show_test_logs();
let (time, sleep) = tick_advance_time_and_sleep();
let (server, response_sender) = channel_server();
let op = operation(server, time.clone(), sleep);
let server = tokio::spawn(async move {
// Delay almost to the end of the grace period before sending anything
tick!(time, Duration::from_secs(4));
for _ in 1..100 {
response_sender.send(NEAT_DATA).await.unwrap();
tick!(time, Duration::from_secs(1));
}
drop(response_sender);
tick!(time, Duration::from_secs(1));
});
let response_body = op.invoke(()).await.expect("initial success");
let result = eagerly_consume(response_body).await;
server.await.unwrap();
result.ok().expect("response MUST NOT timeout");
}
/// Scenario: Download starts fine, and then slowly falls below minimum throughput.
/// Expected: MUST timeout.
#[tokio::test]
async fn download_too_slow() {
let _logs = show_test_logs();
let (time, sleep) = tick_advance_time_and_sleep();
let (server, response_sender) = channel_server();
let op = operation(server, time.clone(), sleep);
let server = tokio::spawn(async move {
// Get slower with every poll
for delay in 1..100 {
let _ = response_sender.send(NEAT_DATA).await;
tick!(time, Duration::from_secs(delay));
}
drop(response_sender);
tick!(time, Duration::from_secs(1));
});
let response_body = op.invoke(()).await.expect("initial success");
let result = eagerly_consume(response_body).await;
server.await.unwrap();
let err = result.expect_err("should have timed out");
assert_str_contains!(
DisplayErrorContext(err.as_ref()).to_string(),
"minimum throughput was specified at 1 B/s, but throughput of 0 B/s was observed"
);
}
/// Scenario: Download starts fine, and then the server stalls and stops sending data.
/// Expected: MUST timeout.
#[tokio::test]
async fn download_stalls() {
let _logs = show_test_logs();
let (time, sleep) = tick_advance_time_and_sleep();
let (server, response_sender) = channel_server();
let op = operation(server, time.clone(), sleep);
let server = tokio::spawn(async move {
for _ in 1..10 {
response_sender.send(NEAT_DATA).await.unwrap();
tick!(time, Duration::from_secs(1));
}
tick!(time, Duration::from_secs(10));
});
let response_body = op.invoke(()).await.expect("initial success");
let result = tokio::spawn(eagerly_consume(response_body));
server.await.unwrap();
let err = result
.await
.expect("no panics")
.expect_err("should have timed out");
assert_str_contains!(
DisplayErrorContext(err.as_ref()).to_string(),
"minimum throughput was specified at 1 B/s, but throughput of 0 B/s was observed"
);
}
/// Scenario: Download starts fine, but then the server stalls for a time within the
/// grace period. Following that, it starts sending data again.
/// Expected: MUST NOT timeout.
#[tokio::test]
async fn download_stall_recovery_in_grace_period() {
let _logs = show_test_logs();
let (time, sleep) = tick_advance_time_and_sleep();
let (server, response_sender) = channel_server();
let op = operation(server, time.clone(), sleep);
let server = tokio::spawn(async move {
for _ in 1..10 {
response_sender.send(NEAT_DATA).await.unwrap();
tick!(time, Duration::from_secs(1));
}
// Delay almost to the end of the grace period
tick!(time, Duration::from_secs(4));
// And now recover
for _ in 1..10 {
response_sender.send(NEAT_DATA).await.unwrap();
tick!(time, Duration::from_secs(1));
}
drop(response_sender);
tick!(time, Duration::from_secs(1));
});
let response_body = op.invoke(()).await.expect("initial success");
let result = eagerly_consume(response_body).await;
server.await.unwrap();
result.ok().expect("response MUST NOT timeout");
}
/// Scenario: The server sends data fast enough, but the customer doesn't consume the
/// data fast enough.
/// Expected: MUST NOT timeout.
#[tokio::test]
async fn user_downloads_data_too_slowly() {
let _logs = show_test_logs();
let (time, sleep) = tick_advance_time_and_sleep();
let (server, response_sender) = channel_server();
let op = operation(server, time.clone(), sleep);
let server = tokio::spawn(async move {
for _ in 1..100 {
response_sender.send(NEAT_DATA).await.unwrap();
}
drop(response_sender);
});
let response_body = op.invoke(()).await.expect("initial success");
let result = slowly_consume(time, response_body).await;
server.await.unwrap();
result.ok().expect("response MUST NOT timeout");
}
use download_test_tools::*;
mod download_test_tools {
use crate::stalled_stream_common::*;
fn response(body: SdkBody) -> HttpResponse {
HttpResponse::try_from(http::Response::builder().status(200).body(body).unwrap()).unwrap()
}
pub fn operation(
http_connector: impl HttpConnector + 'static,
time: TickAdvanceTime,
sleep: TickAdvanceSleep,
) -> Operation<(), SdkBody, Infallible> {
#[derive(Debug)]
struct Deserializer;
impl DeserializeResponse for Deserializer {
fn deserialize_streaming(
&self,
response: &mut HttpResponse,
) -> Option<Result<Output, OrchestratorError<Error>>> {
let mut body = SdkBody::taken();
mem::swap(response.body_mut(), &mut body);
Some(Ok(Output::erase(body)))
}
fn deserialize_nonstreaming(
&self,
_: &HttpResponse,
) -> Result<Output, OrchestratorError<Error>> {
unreachable!()
}
}
let operation = Operation::builder()
.service_name("test")
.operation_name("test")
.http_client(FakeServer(http_connector.into_shared()))
.endpoint_url("http://localhost:1234/doesntmatter")
.no_auth()
.no_retry()
.timeout_config(TimeoutConfig::disabled())
.serializer(|_body: ()| Ok(HttpRequest::new(SdkBody::empty())))
.deserializer_impl(Deserializer)
.stalled_stream_protection(
StalledStreamProtectionConfig::enabled()
.grace_period(Duration::from_secs(5))
.build(),
)
.interceptor(StalledStreamProtectionInterceptor::default())
.sleep_impl(sleep)
.time_source(time)
.build();
operation
}
/// Fake server/connector that responds with a channel body.
pub fn channel_server() -> (SharedHttpConnector, tokio::sync::mpsc::Sender<Bytes>) {
#[derive(Debug)]
struct FakeServerConnector {
body: Arc<Mutex<Option<SdkBody>>>,
}
impl HttpConnector for FakeServerConnector {
fn call(&self, _request: HttpRequest) -> HttpConnectorFuture {
let body = self.body.lock().unwrap().take().unwrap();
HttpConnectorFuture::new(async move { Ok(response(body)) })
}
}
let (body, body_sender) = channel_body();
(
FakeServerConnector {
body: Arc::new(Mutex::new(Some(body))),
}
.into_shared(),
body_sender,
)
}
/// Simulate a client eagerly consuming all the data sent to it from the server.
pub async fn eagerly_consume(body: SdkBody) -> Result<(), BoxError> {
pin_mut!(body);
while let Some(result) = poll_fn(|cx| body.as_mut().poll_data(cx)).await {
if let Err(err) = result {
return Err(err);
} else {
tracing::info!("consumed bytes from the response body");
}
}
Ok(())
}
/// Simulate a client very slowly consuming data with an eager server.
///
/// This implementation will take longer than the grace period to consume
/// the next piece of data.
pub async fn slowly_consume(time: TickAdvanceTime, body: SdkBody) -> Result<(), BoxError> {
pin_mut!(body);
while let Some(result) = poll_fn(|cx| body.as_mut().poll_data(cx)).await {
if let Err(err) = result {
return Err(err);
} else {
tracing::info!("consumed bytes from the response body");
tick!(time, Duration::from_secs(10));
}
}
Ok(())
}
}

View File

@ -7,7 +7,7 @@
use aws_smithy_async::rt::sleep::TokioSleep;
use aws_smithy_async::time::{SystemTimeSource, TimeSource};
use aws_smithy_runtime::client::http::body::minimum_throughput::MinimumThroughputBody;
use aws_smithy_runtime::client::http::body::minimum_throughput::MinimumThroughputDownloadBody;
use aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig;
use aws_smithy_types::body::SdkBody;
use aws_smithy_types::byte_stream::ByteStream;
@ -92,7 +92,7 @@ async fn make_request(address: &str, wrap_body: bool) -> Duration {
let time_source = SystemTimeSource::new();
let sleep = TokioSleep::new();
let opts = StalledStreamProtectionConfig::enabled().build();
let mtb = MinimumThroughputBody::new(time_source, sleep, body, opts.into());
let mtb = MinimumThroughputDownloadBody::new(time_source, sleep, body, opts.into());
SdkBody::from_body_0_4(mtb)
});
}

View File

@ -0,0 +1,342 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
#![cfg(all(feature = "client", feature = "test-util"))]
#[macro_use]
mod stalled_stream_common;
use stalled_stream_common::*;
/// Scenario: Successful upload at a rate above the minimum throughput.
/// Expected: MUST NOT timeout.
#[tokio::test]
async fn upload_success() {
let _logs = show_test_logs();
let (server, time, sleep) = eager_server(true);
let op = operation(server, time, sleep);
let (body, body_sender) = channel_body();
let result = tokio::spawn(async move { op.invoke(body).await });
for _ in 0..100 {
body_sender.send(NEAT_DATA).await.unwrap();
}
drop(body_sender);
assert_eq!(200, result.await.unwrap().expect("success").as_u16());
}
/// Scenario: Upload takes some time to start, but then goes normally.
/// Expected: MUST NOT timeout.
#[tokio::test]
async fn upload_slow_start() {
let _logs = show_test_logs();
let (server, time, sleep) = eager_server(false);
let op = operation(server, time.clone(), sleep);
let (body, body_sender) = channel_body();
let result = tokio::spawn(async move { op.invoke(body).await });
let _streamer = tokio::spawn(async move {
// Advance longer than the grace period. This shouldn't fail since
// it is the customer's side that hasn't produced data yet, not a server issue.
time.tick(Duration::from_secs(10)).await;
for _ in 0..100 {
body_sender.send(NEAT_DATA).await.unwrap();
time.tick(Duration::from_secs(1)).await;
}
drop(body_sender);
time.tick(Duration::from_secs(1)).await;
});
assert_eq!(200, result.await.unwrap().expect("success").as_u16());
}
/// Scenario: The upload is going fine, but falls below the minimum throughput.
/// Expected: MUST timeout.
#[tokio::test]
async fn upload_too_slow() {
let _logs = show_test_logs();
// Server that starts off fast enough, but gets slower over time until it should timeout.
let (server, time, sleep) = time_sequence_server([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
let op = operation(server, time, sleep);
let (body, body_sender) = channel_body();
let result = tokio::spawn(async move { op.invoke(body).await });
let _streamer = tokio::spawn(async move {
for send in 0..100 {
info!("send {send}");
body_sender.send(NEAT_DATA).await.unwrap();
}
drop(body_sender);
});
expect_timeout(result.await.expect("no panics"));
}
/// Scenario: The server stops asking for data, the client maxes out its send buffer,
/// and the request stream stops being polled.
/// Expected: MUST timeout after the grace period completes.
#[tokio::test]
async fn upload_stalls() {
let _logs = show_test_logs();
let (server, time, sleep) = stalling_server();
let op = operation(server, time.clone(), sleep);
let (body, body_sender) = channel_body();
let result = tokio::spawn(async move { op.invoke(body).await });
let _streamer = tokio::spawn(async move {
for send in 1..=100 {
info!("send {send}");
body_sender.send(NEAT_DATA).await.unwrap();
tick!(time, Duration::from_secs(1));
}
drop(body_sender);
time.tick(Duration::from_secs(1)).await;
});
expect_timeout(result.await.expect("no panics"));
}
/// Scenario: All the request data is either uploaded to the server or buffered in the
/// HTTP client, but the response doesn't start coming through within the grace period.
/// Expected: MUST timeout after the grace period completes.
#[tokio::test]
async fn complete_upload_no_response() {
let _logs = show_test_logs();
let (server, time, sleep) = stalling_server();
let op = operation(server, time.clone(), sleep);
let (body, body_sender) = channel_body();
let result = tokio::spawn(async move { op.invoke(body).await });
let _streamer = tokio::spawn(async move {
body_sender.send(NEAT_DATA).await.unwrap();
tick!(time, Duration::from_secs(1));
drop(body_sender);
time.tick(Duration::from_secs(6)).await;
});
expect_timeout(result.await.expect("no panics"));
}
// Scenario: The server stops asking for data, the client maxes out its send buffer,
// and the request stream stops being polled. However, before the grace period
// is over, the server recovers and starts asking for data again.
// Expected: MUST NOT timeout.
#[tokio::test]
async fn upload_stall_recovery_in_grace_period() {
let _logs = show_test_logs();
// Server starts off fast enough, but then slows down almost up to
// the grace period, and then recovers.
let (server, time, sleep) = time_sequence_server([1, 4, 1]);
let op = operation(server, time, sleep);
let (body, body_sender) = channel_body();
let result = tokio::spawn(async move { op.invoke(body).await });
let _streamer = tokio::spawn(async move {
for send in 0..100 {
info!("send {send}");
body_sender.send(NEAT_DATA).await.unwrap();
}
drop(body_sender);
});
assert_eq!(200, result.await.unwrap().expect("success").as_u16());
}
// Scenario: The customer isn't providing data on the stream fast enough to satisfy
// the minimum throughput. This shouldn't be considered a stall since the
// server is asking for more data and could handle it if it were available.
// Expected: MUST NOT timeout.
#[tokio::test]
async fn user_provides_data_too_slowly() {
let _logs = show_test_logs();
let (server, time, sleep) = eager_server(false);
let op = operation(server, time.clone(), sleep.clone());
let (body, body_sender) = channel_body();
let result = tokio::spawn(async move { op.invoke(body).await });
let _streamer = tokio::spawn(async move {
body_sender.send(NEAT_DATA).await.unwrap();
tick!(time, Duration::from_secs(1));
body_sender.send(NEAT_DATA).await.unwrap();
// Now advance 10 seconds before sending more data, simulating a
// customer taking time to produce more data to stream.
tick!(time, Duration::from_secs(10));
body_sender.send(NEAT_DATA).await.unwrap();
drop(body_sender);
tick!(time, Duration::from_secs(1));
});
assert_eq!(200, result.await.unwrap().expect("success").as_u16());
}
use upload_test_tools::*;
mod upload_test_tools {
use crate::stalled_stream_common::*;
pub fn successful_response() -> HttpResponse {
HttpResponse::try_from(
http::Response::builder()
.status(200)
.body(SdkBody::empty())
.unwrap(),
)
.unwrap()
}
pub fn operation(
http_connector: impl HttpConnector + 'static,
time: TickAdvanceTime,
sleep: TickAdvanceSleep,
) -> Operation<SdkBody, StatusCode, Infallible> {
let operation = Operation::builder()
.service_name("test")
.operation_name("test")
.http_client(FakeServer(http_connector.into_shared()))
.endpoint_url("http://localhost:1234/doesntmatter")
.no_auth()
.no_retry()
.timeout_config(TimeoutConfig::disabled())
.serializer(|body: SdkBody| Ok(HttpRequest::new(body)))
.deserializer::<_, Infallible>(|response| Ok(response.status()))
.stalled_stream_protection(
StalledStreamProtectionConfig::enabled()
.grace_period(Duration::from_secs(5))
.build(),
)
.interceptor(StalledStreamProtectionInterceptor::default())
.sleep_impl(sleep)
.time_source(time)
.build();
operation
}
/// Creates a fake HttpConnector implementation that calls the given async $body_fn
/// to get the response body. This $body_fn is given a request body, time, and sleep.
macro_rules! fake_server {
($name:ident, $body_fn:expr) => {
fake_server!($name, $body_fn, (), ())
};
($name:ident, $body_fn:expr, $params_ty:ty, $params:expr) => {{
#[derive(Debug)]
struct $name(TickAdvanceTime, TickAdvanceSleep, $params_ty);
impl HttpConnector for $name {
fn call(&self, mut request: HttpRequest) -> HttpConnectorFuture {
let time = self.0.clone();
let sleep = self.1.clone();
let params = self.2.clone();
let span = tracing::span!(tracing::Level::INFO, "FAKE SERVER");
HttpConnectorFuture::new(
async move {
let mut body = SdkBody::taken();
mem::swap(request.body_mut(), &mut body);
pin_mut!(body);
Ok($body_fn(body, time, sleep, params).await)
}
.instrument(span),
)
}
}
let (time, sleep) = tick_advance_time_and_sleep();
(
$name(time.clone(), sleep.clone(), $params).into_shared(),
time,
sleep,
)
}};
}
/// Fake server/connector that immediately reads all incoming data with an
/// optional 1 second gap in between polls.
pub fn eager_server(
advance_time: bool,
) -> (SharedHttpConnector, TickAdvanceTime, TickAdvanceSleep) {
async fn fake_server(
mut body: Pin<&mut SdkBody>,
time: TickAdvanceTime,
_: TickAdvanceSleep,
advance_time: bool,
) -> HttpResponse {
while poll_fn(|cx| body.as_mut().poll_data(cx)).await.is_some() {
if advance_time {
tick!(time, Duration::from_secs(1));
}
}
successful_response()
}
fake_server!(FakeServerConnector, fake_server, bool, advance_time)
}
/// Fake server/connector that reads some data, and then stalls.
pub fn stalling_server() -> (SharedHttpConnector, TickAdvanceTime, TickAdvanceSleep) {
async fn fake_server(
mut body: Pin<&mut SdkBody>,
_time: TickAdvanceTime,
_sleep: TickAdvanceSleep,
_: (),
) -> HttpResponse {
let mut times = 5;
while times > 0 && poll_fn(|cx| body.as_mut().poll_data(cx)).await.is_some() {
times -= 1;
}
// never awake after this
tracing::info!("stalling indefinitely");
std::future::pending::<()>().await;
unreachable!()
}
fake_server!(FakeServerConnector, fake_server)
}
/// Fake server/connector that polls data after each period of time in the given
/// sequence. Once the sequence completes, it will delay 1 second after each poll.
pub fn time_sequence_server(
time_sequence: impl IntoIterator<Item = u64>,
) -> (SharedHttpConnector, TickAdvanceTime, TickAdvanceSleep) {
async fn fake_server(
mut body: Pin<&mut SdkBody>,
time: TickAdvanceTime,
_sleep: TickAdvanceSleep,
time_sequence: Vec<u64>,
) -> HttpResponse {
let mut time_sequence: VecDeque<Duration> =
time_sequence.into_iter().map(Duration::from_secs).collect();
while poll_fn(|cx| body.as_mut().poll_data(cx)).await.is_some() {
let next_time = time_sequence.pop_front().unwrap_or(Duration::from_secs(1));
tick!(time, next_time);
}
successful_response()
}
fake_server!(
FakeServerConnector,
fake_server,
Vec<u64>,
time_sequence.into_iter().collect()
)
}
pub fn expect_timeout(result: Result<StatusCode, SdkError<Infallible, Response<SdkBody>>>) {
let err = result.expect_err("should have timed out");
assert_str_contains!(
DisplayErrorContext(&err).to_string(),
"minimum throughput was specified at 1 B/s, but throughput of 0 B/s was observed"
);
}
}