Update standard orchestrator retries with token bucket and more tests (#2764)

## Motivation and Context
<!--- Why is this change required? What problem does it solve? -->
<!--- If it fixes an open issue, please link to the issue here -->
Addresses #2743.

## Description
<!--- Describe your changes in detail -->
- Add more tests for the standard retry strategy
- Add an optional token bucket for standard retries
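
A minimal sketch of opting into the new token bucket, assuming the `StandardTokenBucketRuntimePlugin` added in this change and a `with_client_plugin` registration method on `RuntimePlugins` (the module path and registration API are assumptions, not verbatim from this diff):

```rust
use aws_smithy_runtime::client::runtime_plugin::standard_token_bucket::StandardTokenBucketRuntimePlugin;
use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugins;

// Start the bucket with 500 tokens, matching the default capacity in this change.
let token_bucket_plugin = StandardTokenBucketRuntimePlugin::new(500);
// Hypothetical registration call; the exact API surface may differ.
let runtime_plugins = RuntimePlugins::new().with_client_plugin(token_bucket_plugin);
```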

## Testing
<!--- Please describe in detail how you tested your changes -->
<!--- Include details of your testing environment, and the tests you ran
to -->
<!--- see how your change affects other areas of the code, etc. -->
Tests are included (new unit tests for the standard retry strategy and the token bucket).

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._

---------

Co-authored-by: John DiSanti <jdisanti@amazon.com>
Zelda Hessler 2023-06-13 18:49:43 -05:00 committed by GitHub
parent 5473192d3f
commit 312d190535
16 changed files with 656 additions and 456 deletions

View File

@ -465,7 +465,6 @@ async fn build_provider_chain(
#[cfg(test)]
mod test {
use crate::profile::credentials::Builder;
use crate::test_case::TestEnvironment;

View File

@ -146,20 +146,6 @@ open class OperationGenerator(
if (codegenContext.smithyRuntimeMode.generateOrchestrator) {
rustTemplate(
"""
pub(crate) fn register_runtime_plugins(
runtime_plugins: #{RuntimePlugins},
handle: #{Arc}<crate::client::Handle>,
config_override: #{Option}<crate::config::Builder>,
) -> #{RuntimePlugins} {
#{register_default_runtime_plugins}(
runtime_plugins,
#{Box}::new(Self::new()) as _,
handle,
config_override
)
#{additional_runtime_plugins}
}
pub(crate) async fn orchestrate(
runtime_plugins: &#{RuntimePlugins},
input: #{Input},
@ -186,6 +172,20 @@ open class OperationGenerator(
let input = #{TypedBox}::new(input).erase();
#{invoke_with_stop_point}(input, runtime_plugins, stop_point).await
}
pub(crate) fn register_runtime_plugins(
runtime_plugins: #{RuntimePlugins},
handle: #{Arc}<crate::client::Handle>,
config_override: #{Option}<crate::config::Builder>,
) -> #{RuntimePlugins} {
#{register_default_runtime_plugins}(
runtime_plugins,
#{Box}::new(Self::new()) as _,
handle,
config_override
)
#{additional_runtime_plugins}
}
""",
*codegenScope,
"Error" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::interceptors::context::Error"),

View File

@ -105,6 +105,8 @@ class ServiceRuntimePluginGenerator(
"StaticAuthOptionResolver" to runtimeApi.resolve("client::auth::option_resolver::StaticAuthOptionResolver"),
"default_connector" to client.resolve("conns::default_connector"),
"require_connector" to client.resolve("conns::require_connector"),
"TimeoutConfig" to smithyTypes.resolve("timeout::TimeoutConfig"),
"RetryConfig" to smithyTypes.resolve("retry::RetryConfig"),
)
}
@ -142,20 +144,17 @@ class ServiceRuntimePluginGenerator(
self.handle.conf.endpoint_resolver());
cfg.set_endpoint_resolver(endpoint_resolver);
// TODO(enableNewSmithyRuntime): Use the `store_append` method of ConfigBag to insert classifiers
let retry_classifiers = #{RetryClassifiers}::new()
#{retry_classifier_customizations};
cfg.set_retry_classifiers(retry_classifiers);
// TODO(enableNewSmithyRuntime): Make it possible to set retry classifiers at the service level.
// Retry classifiers can also be set at the operation level and those should be added to the
// list of classifiers defined here, rather than replacing them.
let sleep_impl = self.handle.conf.sleep_impl();
let timeout_config = self.handle.conf.timeout_config();
let retry_config = self.handle.conf.retry_config();
let timeout_config = self.handle.conf.timeout_config().cloned().unwrap_or_else(|| #{TimeoutConfig}::disabled());
let retry_config = self.handle.conf.retry_config().cloned().unwrap_or_else(|| #{RetryConfig}::disabled());
if let Some(retry_config) = retry_config {
cfg.set_retry_strategy(#{StandardRetryStrategy}::new(retry_config));
}
cfg.set_retry_strategy(#{StandardRetryStrategy}::new(&retry_config));
let connector_settings = timeout_config.map(#{ConnectorSettings}::from_timeout_config).unwrap_or_default();
let connector_settings = #{ConnectorSettings}::from_timeout_config(&timeout_config);
if let Some(connection) = self.handle.conf.http_connector()
.and_then(|c| c.connector(&connector_settings, sleep_impl.clone()))
.or_else(|| #{default_connector}(&connector_settings, sleep_impl)) {
@ -180,9 +179,6 @@ class ServiceRuntimePluginGenerator(
"http_auth_scheme_customizations" to writable {
writeCustomizations(customizations, ServiceRuntimePluginSection.HttpAuthScheme("cfg"))
},
"retry_classifier_customizations" to writable {
writeCustomizations(customizations, ServiceRuntimePluginSection.RetryClassifier("cfg"))
},
"additional_config" to writable {
writeCustomizations(customizations, ServiceRuntimePluginSection.AdditionalConfig("cfg"))
},
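
Net effect on the generated Rust: the retry and timeout configs now fall back to `disabled()` defaults instead of remaining `Option`s. Roughly what the template emits after this change (a sketch with the `#{...}` names resolved, not verbatim codegen output):

```rust
let retry_config = self.handle.conf.retry_config()
    .cloned()
    .unwrap_or_else(RetryConfig::disabled);
// The retry strategy is now set unconditionally.
cfg.set_retry_strategy(StandardRetryStrategy::new(&retry_config));

let timeout_config = self.handle.conf.timeout_config()
    .cloned()
    .unwrap_or_else(TimeoutConfig::disabled);
let connector_settings = ConnectorSettings::from_timeout_config(&timeout_config);
```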

View File

@ -18,6 +18,12 @@ pub enum OrchestratorError<E> {
Interceptor { err: InterceptorError },
/// An error returned by a service.
Operation { err: E },
/// An error that occurs when a request times out.
Timeout { err: BoxError },
/// An error that occurs when request dispatch fails.
Connector { err: ConnectorError },
/// An error that occurs when a response can't be deserialized.
Response { err: BoxError },
/// A general orchestrator error.
Other { err: BoxError },
}
@ -34,11 +40,26 @@ impl<E: Debug> OrchestratorError<E> {
Self::Operation { err }
}
/// Create a new `OrchestratorError` from an [`InterceptorError`].
/// Create a new `OrchestratorError::Interceptor` from an [`InterceptorError`].
pub fn interceptor(err: InterceptorError) -> Self {
Self::Interceptor { err }
}
/// Create a new `OrchestratorError::Timeout` from a [`BoxError`].
pub fn timeout(err: BoxError) -> Self {
Self::Timeout { err }
}
/// Create a new `OrchestratorError::Response` from a [`BoxError`].
pub fn response(err: BoxError) -> Self {
Self::Response { err }
}
/// Create a new `OrchestratorError::Connector` from a [`ConnectorError`].
pub fn connector(err: ConnectorError) -> Self {
Self::Connector { err }
}
/// Convert the `OrchestratorError` into `Some` operation-specific error if it is one. Otherwise,
/// return `None`.
pub fn as_operation_error(&self) -> Option<&E> {
@ -72,6 +93,9 @@ impl<E: Debug> OrchestratorError<E> {
debug_assert!(phase.is_after_deserialization(), "operation errors are a result of successfully receiving and parsing a response from the server. Therefore, we must be in the 'After Deserialization' phase.");
SdkError::service_error(err, response.expect("phase has a response"))
}
Self::Connector { err } => SdkError::dispatch_failure(err),
Self::Timeout { err } => SdkError::timeout_error(err),
Self::Response { err } => SdkError::response_error(err, response.unwrap()),
Self::Other { err } => {
use Phase::*;
match phase {
@ -111,15 +135,6 @@ where
}
}
impl<E> From<BoxError> for OrchestratorError<E>
where
E: Debug + std::error::Error + 'static,
{
fn from(err: BoxError) -> Self {
Self::other(err)
}
}
impl From<TypeErasedError> for OrchestratorError<TypeErasedError> {
fn from(err: TypeErasedError) -> Self {
Self::operation(err)
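
The new constructors give the orchestrator a typed way to tag each failure source. A hedged sketch of their use (`ConnectorError::io` is assumed from `aws-smithy-http`; `Infallible` just stands in for an operation error type):

```rust
use std::convert::Infallible;

let timeout: OrchestratorError<Infallible> =
    OrchestratorError::timeout("request took too long".into());
let response: OrchestratorError<Infallible> =
    OrchestratorError::response("failed to read the response body".into());
let connector: OrchestratorError<Infallible> =
    OrchestratorError::connector(ConnectorError::io("connection reset".into()));
```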

View File

@ -20,6 +20,16 @@ pub enum ShouldAttempt {
YesAfterDelay(Duration),
}
#[cfg(feature = "test-util")]
impl ShouldAttempt {
pub fn expect_delay(self) -> Duration {
match self {
ShouldAttempt::YesAfterDelay(delay) => delay,
_ => panic!("Expected this to be the `YesAfterDelay` variant but it was the `{self:?}` variant instead"),
}
}
}
pub trait RetryStrategy: Send + Sync + Debug {
fn should_attempt_initial_request(&self, cfg: &ConfigBag) -> Result<ShouldAttempt, BoxError>;
@ -31,7 +41,7 @@ pub trait RetryStrategy: Send + Sync + Debug {
}
#[non_exhaustive]
#[derive(Eq, PartialEq, Debug)]
#[derive(Clone, Eq, PartialEq, Debug)]
pub enum RetryReason {
Error(ErrorKind),
Explicit(Duration),
@ -72,10 +82,10 @@ impl RetryClassifiers {
}
impl ClassifyRetry for RetryClassifiers {
fn classify_retry(&self, error: &InterceptorContext) -> Option<RetryReason> {
fn classify_retry(&self, ctx: &InterceptorContext) -> Option<RetryReason> {
// return the first non-None result
self.inner.iter().find_map(|cr| {
let maybe_reason = cr.classify_retry(error);
let maybe_reason = cr.classify_retry(ctx);
match maybe_reason.as_ref() {
Some(reason) => trace!(
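
The `test-util` helper above keeps retry tests terse; a typical call (mirroring the tests later in this diff) looks like:

```rust
// Unwraps the `YesAfterDelay` variant, or panics with the actual variant.
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let delay = should_retry.expect_delay();
assert_eq!(delay, Duration::from_secs(1));
```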

View File

@ -1,13 +0,0 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
//! Code for rate-limiting smithy clients.
pub mod error;
pub mod token;
pub mod token_bucket;
pub use token::Token;
pub use token_bucket::TokenBucket;

View File

@ -1,50 +0,0 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
//! Errors related to rate limiting
use std::fmt;
/// Errors related to a token bucket.
#[derive(Debug)]
pub struct RateLimitingError {
kind: ErrorKind,
}
impl RateLimitingError {
/// An error that occurs when no tokens are left in the bucket.
pub fn no_tokens() -> Self {
Self {
kind: ErrorKind::NoTokens,
}
}
/// An error that occurs due to a bug in the code. Please report bugs you encounter.
pub fn bug(s: impl ToString) -> Self {
Self {
kind: ErrorKind::Bug(s.to_string()),
}
}
}
#[derive(Debug)]
enum ErrorKind {
/// A token was requested but there were no tokens left in the bucket.
NoTokens,
/// This error should never occur and is a bug. Please report it.
Bug(String),
}
impl fmt::Display for RateLimitingError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use ErrorKind::*;
match &self.kind {
NoTokens => write!(f, "No more tokens are left in the bucket."),
Bug(msg) => write!(f, "you've encountered a bug that needs reporting: {}", msg),
}
}
}
impl std::error::Error for RateLimitingError {}

View File

@ -1,65 +0,0 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
//! Types and traits related to token buckets. Token buckets are used to limit the amount of
//! requests a client sends in order to avoid getting throttled. Token buckets can also act as a
//! form of concurrency control if a token is required to send a new request (as opposed to retry
//! requests only).
use tokio::sync::OwnedSemaphorePermit;
/// A trait implemented by types that represent a token dispensed from a [`TokenBucket`](super::TokenBucket).
pub trait Token {
/// Release this token back to the bucket. This should be called if the related request succeeds.
fn release(self);
/// Forget this token, forever banishing it to the shadow realm, from whence no tokens return.
/// This should be called if the related request fails.
fn forget(self);
}
/// The token type of [`Standard`].
#[derive(Debug)]
pub struct Standard {
permit: Option<OwnedSemaphorePermit>,
}
impl Standard {
pub(crate) fn new(permit: OwnedSemaphorePermit) -> Self {
Self {
permit: Some(permit),
}
}
// Return an "empty" token for times when you need to return a token but there's no "cost"
// associated with an action.
pub(crate) fn empty() -> Self {
Self { permit: None }
}
}
impl Token for Standard {
fn release(self) {
drop(self.permit)
}
fn forget(self) {
if let Some(permit) = self.permit {
permit.forget()
}
}
}
#[cfg(test)]
mod tests {
use super::Standard as Token;
use crate::client::retries::rate_limiting::token_bucket::Standard as TokenBucket;
#[test]
fn token_bucket_trait_is_dyn_safe() {
let _tb: Box<dyn crate::client::retries::rate_limiting::TokenBucket<Token = Token>> =
Box::new(TokenBucket::builder().build());
}
}

View File

@ -1,235 +0,0 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
//! A token bucket intended for use with the standard smithy client retry policy.
use super::error::RateLimitingError;
use super::token;
use super::Token;
use aws_smithy_types::retry::{ErrorKind, RetryKind};
use std::sync::Arc;
use tokio::sync::Semaphore;
use tokio::sync::TryAcquireError;
/// The default number of tokens to start with
const STANDARD_INITIAL_RETRY_TOKENS: usize = 500;
/// The amount of tokens to remove from the bucket when a timeout error occurs
const STANDARD_TIMEOUT_ERROR_RETRY_COST: u32 = 10;
/// The amount of tokens to remove from the bucket when a throttling error occurs
const STANDARD_RETRYABLE_ERROR_RETRY_COST: u32 = 5;
/// This trait is implemented by types that act as token buckets. Token buckets are used to regulate
/// the amount of requests sent by clients. Different token buckets may apply different strategies
/// to manage the number of tokens in a bucket.
///
/// related: [`Token`], [`RateLimitingError`]
pub trait TokenBucket {
/// The type of tokens this bucket dispenses.
type Token: Token;
/// Attempt to acquire a token from the bucket. This will fail if the bucket has no more tokens.
fn try_acquire(
&self,
previous_response_kind: Option<RetryKind>,
) -> Result<Self::Token, RateLimitingError>;
/// Get the number of available tokens in the bucket.
fn available(&self) -> usize;
/// Refill the bucket with the given number of tokens.
fn refill(&self, tokens: usize);
}
/// A token bucket implementation that uses a `tokio::sync::Semaphore` to track the number of tokens.
///
/// - Whenever a request succeeds on the first try, `<success_on_first_try_refill_amount>` token(s)
/// are added back to the bucket.
/// - When a request fails with a timeout error, `<timeout_error_cost>` token(s)
/// are removed from the bucket.
/// - When a request fails with a retryable error, `<retryable_error_cost>` token(s)
/// are removed from the bucket.
///
/// The number of tokens in the bucket will always be >= `0` and <= `<max_tokens>`.
#[derive(Clone, Debug)]
pub struct Standard {
inner: Arc<Semaphore>,
max_tokens: usize,
timeout_error_cost: u32,
retryable_error_cost: u32,
}
impl Standard {
/// Create a new `TokenBucket` using builder methods.
pub fn builder() -> Builder {
Builder::default()
}
}
/// A builder for `TokenBucket`s.
#[derive(Default, Debug)]
pub struct Builder {
starting_tokens: Option<usize>,
max_tokens: Option<usize>,
timeout_error_cost: Option<u32>,
retryable_error_cost: Option<u32>,
}
impl Builder {
/// The number of tokens the bucket will start with. Defaults to 500.
pub fn starting_tokens(mut self, starting_tokens: usize) -> Self {
self.starting_tokens = Some(starting_tokens);
self
}
/// The maximum number of tokens that the bucket can hold.
/// Defaults to the value of `starting_tokens`.
pub fn max_tokens(mut self, max_tokens: usize) -> Self {
self.max_tokens = Some(max_tokens);
self
}
/// How many tokens to remove from the bucket when a request fails due to a timeout error.
/// Defaults to 10.
pub fn timeout_error_cost(mut self, timeout_error_cost: u32) -> Self {
self.timeout_error_cost = Some(timeout_error_cost);
self
}
/// How many tokens to remove from the bucket when a request fails due to a retryable error that
/// isn't timeout-related. Defaults to 5.
pub fn retryable_error_cost(mut self, retryable_error_cost: u32) -> Self {
self.retryable_error_cost = Some(retryable_error_cost);
self
}
/// Build this builder. Unset fields will be set to their default values.
pub fn build(self) -> Standard {
let starting_tokens = self
.starting_tokens
.unwrap_or(STANDARD_INITIAL_RETRY_TOKENS);
let max_tokens = self.max_tokens.unwrap_or(starting_tokens);
let timeout_error_cost = self
.timeout_error_cost
.unwrap_or(STANDARD_TIMEOUT_ERROR_RETRY_COST);
let retryable_error_cost = self
.retryable_error_cost
.unwrap_or(STANDARD_RETRYABLE_ERROR_RETRY_COST);
Standard {
inner: Arc::new(Semaphore::new(starting_tokens)),
max_tokens,
timeout_error_cost,
retryable_error_cost,
}
}
}
impl TokenBucket for Standard {
type Token = token::Standard;
fn try_acquire(
&self,
previous_response_kind: Option<RetryKind>,
) -> Result<Self::Token, RateLimitingError> {
let number_of_tokens_to_acquire = match previous_response_kind {
None => {
// Return an empty token because the quota layer lifecycle expects a token for each
// request even though the standard token bucket only requires tokens for retry
// attempts.
return Ok(token::Standard::empty());
}
Some(retry_kind) => match retry_kind {
RetryKind::Unnecessary => {
unreachable!("BUG: asked for a token to retry a successful request")
}
RetryKind::UnretryableFailure => {
unreachable!("BUG: asked for a token to retry an un-retryable request")
}
RetryKind::Explicit(_) => self.retryable_error_cost,
RetryKind::Error(error_kind) => match error_kind {
ErrorKind::ThrottlingError | ErrorKind::TransientError => {
self.timeout_error_cost
}
ErrorKind::ServerError => self.retryable_error_cost,
ErrorKind::ClientError => unreachable!(
"BUG: asked for a token to retry a request that failed due to user error"
),
_ => unreachable!(
"A new variant '{:?}' was added to ErrorKind, please handle it",
error_kind
),
},
_ => unreachable!(
"A new variant '{:?}' was added to RetryKind, please handle it",
retry_kind
),
},
};
match self
.inner
.clone()
.try_acquire_many_owned(number_of_tokens_to_acquire)
{
Ok(permit) => Ok(token::Standard::new(permit)),
Err(TryAcquireError::NoPermits) => Err(RateLimitingError::no_tokens()),
Err(other) => Err(RateLimitingError::bug(other.to_string())),
}
}
fn available(&self) -> usize {
self.inner.available_permits()
}
fn refill(&self, tokens: usize) {
// Ensure the bucket doesn't overflow by limiting the amount of tokens to add, if necessary.
let amount_to_add = (self.available() + tokens).min(self.max_tokens) - self.available();
if amount_to_add > 0 {
self.inner.add_permits(amount_to_add)
}
}
}
#[cfg(test)]
mod test {
use super::{Token, TokenBucket};
use super::{
STANDARD_INITIAL_RETRY_TOKENS, STANDARD_RETRYABLE_ERROR_RETRY_COST,
STANDARD_TIMEOUT_ERROR_RETRY_COST,
};
use aws_smithy_types::retry::{ErrorKind, RetryKind};
#[test]
fn bucket_works() {
let bucket = super::Standard::builder().build();
assert_eq!(bucket.available(), STANDARD_INITIAL_RETRY_TOKENS);
let token = bucket
.try_acquire(Some(RetryKind::Error(ErrorKind::ServerError)))
.unwrap();
assert_eq!(
bucket.available(),
STANDARD_INITIAL_RETRY_TOKENS - STANDARD_RETRYABLE_ERROR_RETRY_COST as usize
);
Box::new(token).release();
let token = bucket
.try_acquire(Some(RetryKind::Error(ErrorKind::TransientError)))
.unwrap();
assert_eq!(
bucket.available(),
STANDARD_INITIAL_RETRY_TOKENS - STANDARD_TIMEOUT_ERROR_RETRY_COST as usize
);
Box::new(token).forget();
assert_eq!(
bucket.available(),
STANDARD_INITIAL_RETRY_TOKENS - STANDARD_TIMEOUT_ERROR_RETRY_COST as usize
);
bucket.refill(STANDARD_TIMEOUT_ERROR_RETRY_COST as usize);
assert_eq!(bucket.available(), STANDARD_INITIAL_RETRY_TOKENS);
}
}

View File

@ -139,7 +139,7 @@ async fn try_op(
{
let request_serializer = cfg.request_serializer();
let input = ctx.take_input().expect("input set at this point");
let request = halt_on_err!([ctx] => request_serializer.serialize_input(input, cfg));
let request = halt_on_err!([ctx] => request_serializer.serialize_input(input, cfg).map_err(OrchestratorError::other));
ctx.set_request(request);
}
@ -171,10 +171,10 @@ async fn try_op(
// No, this request shouldn't be sent
Ok(ShouldAttempt::No) => {
let err: BoxError = "the retry strategy indicates that an initial request shouldn't be made, but it didn't specify why".into();
halt!([ctx] => err);
halt!([ctx] => OrchestratorError::other(err));
}
// No, we shouldn't make a request because...
Err(err) => halt!([ctx] => err),
Err(err) => halt!([ctx] => OrchestratorError::other(err)),
Ok(ShouldAttempt::YesAfterDelay(_)) => {
unreachable!("Delaying the initial request is currently unsupported. If this feature is important to you, please file an issue in GitHub.")
}
@ -183,7 +183,7 @@ async fn try_op(
// Save a request checkpoint before we make the request. This will allow us to "rewind"
// the request in the case of retry attempts.
ctx.save_checkpoint();
for i in 0usize.. {
for i in 1usize.. {
debug!("beginning attempt #{i}");
// Break from the loop if we can't rewind the request's state. This will always succeed the
// first time, but will fail on subsequent iterations if the request body wasn't retryable.
@ -201,19 +201,21 @@ async fn try_op(
}
.maybe_timeout_with_config(attempt_timeout_config)
.await
.map_err(OrchestratorError::other);
.map_err(|err| OrchestratorError::timeout(err.into_source().unwrap()));
// We continue when encountering a timeout error. The retry classifier will decide what to do with it.
continue_on_err!([ctx] => maybe_timeout);
let retry_strategy = cfg.retry_strategy();
// If we got a retry strategy from the bag, ask it what to do.
// If no strategy was set, we won't retry.
let should_attempt = halt_on_err!(
[ctx] => retry_strategy
.map(|rs| rs.should_attempt_retry(ctx, cfg))
.unwrap_or(Ok(ShouldAttempt::No)
));
let should_attempt = match retry_strategy {
Some(retry_strategy) => halt_on_err!(
[ctx] => retry_strategy.should_attempt_retry(ctx, cfg).map_err(OrchestratorError::other)
),
None => ShouldAttempt::No,
};
match should_attempt {
// Yes, let's retry the request
ShouldAttempt::Yes => continue,
@ -241,11 +243,11 @@ async fn try_attempt(
stop_point: StopPoint,
) {
halt_on_err!([ctx] => interceptors.read_before_attempt(ctx, cfg));
halt_on_err!([ctx] => orchestrate_endpoint(ctx, cfg));
halt_on_err!([ctx] => orchestrate_endpoint(ctx, cfg).map_err(OrchestratorError::other));
halt_on_err!([ctx] => interceptors.modify_before_signing(ctx, cfg));
halt_on_err!([ctx] => interceptors.read_before_signing(ctx, cfg));
halt_on_err!([ctx] => orchestrate_auth(ctx, cfg).await);
halt_on_err!([ctx] => orchestrate_auth(ctx, cfg).await.map_err(OrchestratorError::other));
halt_on_err!([ctx] => interceptors.read_after_signing(ctx, cfg));
halt_on_err!([ctx] => interceptors.modify_before_transmit(ctx, cfg));
@ -261,7 +263,12 @@ async fn try_attempt(
ctx.enter_transmit_phase();
let call_result = halt_on_err!([ctx] => {
let request = ctx.take_request().expect("set during serialization");
cfg.connection().call(request).await
cfg.connection().call(request).await.map_err(|err| {
match err.downcast() {
Ok(connector_error) => OrchestratorError::connector(*connector_error),
Err(box_err) => OrchestratorError::other(box_err)
}
})
});
ctx.set_response(call_result);
ctx.enter_before_deserialization_phase();
@ -279,7 +286,7 @@ async fn try_attempt(
None => read_body(response)
.instrument(debug_span!("read_body"))
.await
.map_err(OrchestratorError::other)
.map_err(OrchestratorError::response)
.and_then(|_| response_deserializer.deserialize_nonstreaming(response)),
}
}
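
The inline `downcast` in the transmit phase is where dispatch failures get typed. The same mapping as a standalone sketch (under the assumption that `BoxError` is `Box<dyn Error + Send + Sync>`, as elsewhere in the runtime API):

```rust
use std::fmt::Debug;

fn classify_dispatch_failure<E: Debug>(err: BoxError) -> OrchestratorError<E> {
    // `ConnectorError`s get the dedicated `Connector` variant so retry
    // classifiers can inspect them; anything else falls back to `Other`.
    match err.downcast::<ConnectorError>() {
        Ok(connector_error) => OrchestratorError::connector(*connector_error),
        Err(box_err) => OrchestratorError::other(box_err),
    }
}
```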

View File

@ -3,8 +3,8 @@
* SPDX-License-Identifier: Apache-2.0
*/
use aws_smithy_http::result::SdkError;
use aws_smithy_runtime_api::client::interceptors::InterceptorContext;
use aws_smithy_runtime_api::client::orchestrator::OrchestratorError;
use aws_smithy_runtime_api::client::retries::{ClassifyRetry, RetryReason};
use aws_smithy_types::retry::{ErrorKind, ProvideErrorKind};
use std::borrow::Cow;
@ -76,17 +76,17 @@ where
Ok(_) => return None,
Err(err) => err,
};
// Check that the error is an operation error
let error = error.as_operation_error()?;
// Downcast the error
let error = error.downcast_ref::<SdkError<E>>()?;
match error {
SdkError::TimeoutError(_) => Some(RetryReason::Error(ErrorKind::TransientError)),
SdkError::ResponseError { .. } => Some(RetryReason::Error(ErrorKind::TransientError)),
SdkError::DispatchFailure(err) if (err.is_timeout() || err.is_io()) => {
OrchestratorError::Response { .. } | OrchestratorError::Timeout { .. } => {
Some(RetryReason::Error(ErrorKind::TransientError))
}
SdkError::DispatchFailure(err) => err.is_other().map(RetryReason::Error),
OrchestratorError::Connector { err } if err.is_timeout() || err.is_io() => {
Some(RetryReason::Error(ErrorKind::TransientError))
}
OrchestratorError::Connector { err } if err.is_other().is_some() => {
err.is_other().map(RetryReason::Error)
}
_ => None,
}
}
@ -152,8 +152,6 @@ mod test {
HttpStatusCodeClassifier, ModeledAsRetryableClassifier,
};
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::operation;
use aws_smithy_http::result::SdkError;
use aws_smithy_runtime_api::client::interceptors::InterceptorContext;
use aws_smithy_runtime_api::client::orchestrator::OrchestratorError;
use aws_smithy_runtime_api::client::retries::{ClassifyRetry, RetryReason};
@ -242,11 +240,10 @@ mod test {
#[test]
fn classify_response_error() {
let policy = SmithyErrorClassifier::<UnmodeledError>::new();
let test_response = http::Response::new("OK").map(SdkBody::from);
let err: SdkError<UnmodeledError> =
SdkError::response_error(UnmodeledError, operation::Response::new(test_response));
let mut ctx = InterceptorContext::new(TypeErasedBox::new("doesntmatter"));
ctx.set_output_or_error(Err(OrchestratorError::operation(TypeErasedError::new(err))));
ctx.set_output_or_error(Err(OrchestratorError::response(
"I am a response error".into(),
)));
assert_eq!(
policy.classify_retry(&ctx),
Some(RetryReason::Error(ErrorKind::TransientError)),
@ -256,9 +253,10 @@ mod test {
#[test]
fn test_timeout_error() {
let policy = SmithyErrorClassifier::<UnmodeledError>::new();
let err: SdkError<UnmodeledError> = SdkError::timeout_error("blah");
let mut ctx = InterceptorContext::new(TypeErasedBox::new("doesntmatter"));
ctx.set_output_or_error(Err(OrchestratorError::operation(TypeErasedError::new(err))));
ctx.set_output_or_error(Err(OrchestratorError::timeout(
"I am a timeout error".into(),
)));
assert_eq!(
policy.classify_retry(&ctx),
Some(RetryReason::Error(ErrorKind::TransientError)),
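
A companion check for the new `Connector` arm, in the same style as the two tests above (`ConnectorError::io` is assumed from `aws-smithy-http`; I/O failures classify as transient):

```rust
#[test]
fn classify_connector_io_error() {
    let policy = SmithyErrorClassifier::<UnmodeledError>::new();
    let mut ctx = InterceptorContext::new(TypeErasedBox::new("doesntmatter"));
    ctx.set_output_or_error(Err(OrchestratorError::connector(ConnectorError::io(
        "connection reset".into(),
    ))));
    assert_eq!(
        policy.classify_retry(&ctx),
        Some(RetryReason::Error(ErrorKind::TransientError)),
    );
}
```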

View File

@ -6,7 +6,7 @@
#[cfg(feature = "test-util")]
mod fixed_delay;
mod never;
mod standard;
pub(crate) mod standard;
#[cfg(feature = "test-util")]
pub use fixed_delay::FixedDelayRetryStrategy;

View File

@ -3,6 +3,10 @@
* SPDX-License-Identifier: Apache-2.0
*/
use crate::client::retries::strategy::standard::ReleaseResult::{
APermitWasReleased, NoPermitWasReleased,
};
use crate::client::runtime_plugin::standard_token_bucket::StandardTokenBucket;
use aws_smithy_runtime_api::client::interceptors::InterceptorContext;
use aws_smithy_runtime_api::client::orchestrator::{BoxError, ConfigBagAccessors};
use aws_smithy_runtime_api::client::request_attempts::RequestAttempts;
@ -11,16 +15,21 @@ use aws_smithy_runtime_api::client::retries::{
};
use aws_smithy_types::config_bag::ConfigBag;
use aws_smithy_types::retry::RetryConfig;
use std::sync::Mutex;
use std::time::Duration;
use tokio::sync::OwnedSemaphorePermit;
// The initial attempt, plus three retries.
const DEFAULT_MAX_ATTEMPTS: usize = 4;
#[derive(Debug)]
pub struct StandardRetryStrategy {
max_attempts: usize,
initial_backoff: Duration,
max_backoff: Duration,
// Retry settings
base: fn() -> f64,
initial_backoff: Duration,
max_attempts: usize,
max_backoff: Duration,
retry_permit: Mutex<Option<OwnedSemaphorePermit>>,
}
impl StandardRetryStrategy {
@ -45,6 +54,36 @@ impl StandardRetryStrategy {
self.initial_backoff = initial_backoff;
self
}
pub fn with_max_backoff(mut self, max_backoff: Duration) -> Self {
self.max_backoff = max_backoff;
self
}
fn release_retry_permit(&self) -> ReleaseResult {
let mut retry_permit = self.retry_permit.lock().unwrap();
match retry_permit.take() {
Some(p) => {
drop(p);
APermitWasReleased
}
None => NoPermitWasReleased,
}
}
fn set_retry_permit(&self, new_retry_permit: OwnedSemaphorePermit) {
let mut old_retry_permit = self.retry_permit.lock().unwrap();
if let Some(p) = old_retry_permit.replace(new_retry_permit) {
// Whenever we set a new retry permit and it replaces the old one, we need to "forget"
// the old permit, removing it from the bucket forever.
p.forget()
}
}
}
enum ReleaseResult {
APermitWasReleased,
NoPermitWasReleased,
}
impl Default for StandardRetryStrategy {
@ -55,13 +94,14 @@ impl Default for StandardRetryStrategy {
// by default, use a random base for exponential backoff
base: fastrand::f64,
initial_backoff: Duration::from_secs(1),
retry_permit: Mutex::new(None),
}
}
}
impl RetryStrategy for StandardRetryStrategy {
// TODO(token-bucket) add support for optional cross-request token bucket
fn should_attempt_initial_request(&self, _cfg: &ConfigBag) -> Result<ShouldAttempt, BoxError> {
// The standard token bucket is only ever considered for retry requests.
Ok(ShouldAttempt::Yes)
}
@ -74,18 +114,31 @@ impl RetryStrategy for StandardRetryStrategy {
let output_or_error = ctx.output_or_error().expect(
"This must never be called without reaching the point where the result exists.",
);
let token_bucket = cfg.get::<StandardTokenBucket>();
if output_or_error.is_ok() {
tracing::debug!("request succeeded, no retry necessary");
if let Some(tb) = token_bucket {
// If this retry strategy is holding any permits, release them back to the bucket.
if let NoPermitWasReleased = self.release_retry_permit() {
// In the event that there was no retry permit to release, we generate new
// permits from nothing. We do this to make up for permits we had to "forget".
// Otherwise, repeated retries would empty the bucket and nothing could fill it
// back up again.
tb.regenerate_a_token();
}
}
return Ok(ShouldAttempt::No);
}
// Check if we're out of attempts
let request_attempts: &RequestAttempts = cfg
.get()
.expect("at least one request attempt is made before any retry is attempted");
if request_attempts.attempts() >= self.max_attempts {
let request_attempts = cfg
.get::<RequestAttempts>()
.expect("at least one request attempt is made before any retry is attempted")
.attempts();
if request_attempts >= self.max_attempts {
tracing::trace!(
attempts = request_attempts.attempts(),
attempts = request_attempts,
max_attempts = self.max_attempts,
"not retrying because we are out of attempts"
);
@ -95,9 +148,24 @@ impl RetryStrategy for StandardRetryStrategy {
// Run the classifiers against the context to determine if we should retry
let retry_classifiers = cfg.retry_classifiers();
let retry_reason = retry_classifiers.classify_retry(ctx);
// Calculate the appropriate backoff time.
let backoff = match retry_reason {
Some(RetryReason::Explicit(dur)) => dur,
Some(RetryReason::Error(_)) => {
Some(RetryReason::Error(kind)) => {
// If a token bucket was set, and the RetryReason IS NOT explicit, attempt to acquire a retry permit.
if let Some(tb) = token_bucket {
match tb.acquire(&kind) {
Some(permit) => self.set_retry_permit(permit),
None => {
tracing::debug!(
"attempt #{request_attempts} failed with {kind:?}; However, no retry permits are available, so no retry will be attempted.",
);
return Ok(ShouldAttempt::No);
}
}
};
let backoff = calculate_exponential_backoff(
// Generate a random base multiplier to create jitter
(self.base)(),
@ -105,16 +173,14 @@ impl RetryStrategy for StandardRetryStrategy {
self.initial_backoff.as_secs_f64(),
// `self.local.attempts` tracks number of requests made including the initial request
// The initial attempt shouldn't count towards backoff calculations so we subtract it
(request_attempts.attempts() - 1) as u32,
(request_attempts - 1) as u32,
);
Duration::from_secs_f64(backoff).min(self.max_backoff)
}
Some(_) => {
unreachable!("RetryReason is non-exhaustive. Therefore, we need to cover this unreachable case.")
}
Some(_) => unreachable!("RetryReason is non-exhaustive"),
None => {
tracing::trace!(
attempts = request_attempts.attempts(),
attempts = request_attempts,
max_attempts = self.max_attempts,
"encountered unretryable error"
);
@ -123,8 +189,7 @@ impl RetryStrategy for StandardRetryStrategy {
};
tracing::debug!(
"attempt {} failed with {:?}; retrying after {:?}",
request_attempts.attempts(),
"attempt #{request_attempts} failed with {:?}; retrying after {:?}",
retry_reason.expect("the match statement above ensures this is not None"),
backoff
);
@ -139,16 +204,23 @@ fn calculate_exponential_backoff(base: f64, initial_backoff: f64, retry_attempts
#[cfg(test)]
mod tests {
use super::{ShouldAttempt, StandardRetryStrategy};
use super::{calculate_exponential_backoff, ShouldAttempt, StandardRetryStrategy};
use aws_smithy_runtime_api::client::interceptors::InterceptorContext;
use aws_smithy_runtime_api::client::orchestrator::{ConfigBagAccessors, OrchestratorError};
use aws_smithy_runtime_api::client::request_attempts::RequestAttempts;
use aws_smithy_runtime_api::client::retries::{AlwaysRetry, RetryClassifiers, RetryStrategy};
use aws_smithy_runtime_api::client::retries::{
AlwaysRetry, ClassifyRetry, RetryClassifiers, RetryReason, RetryStrategy,
};
use aws_smithy_types::config_bag::{ConfigBag, Layer};
use aws_smithy_types::retry::ErrorKind;
use aws_smithy_types::retry::{ErrorKind, ProvideErrorKind};
use aws_smithy_types::type_erasure::TypeErasedBox;
use std::fmt;
use std::sync::Mutex;
use std::time::Duration;
#[cfg(feature = "test-util")]
use crate::client::runtime_plugin::standard_token_bucket::StandardTokenBucket;
#[test]
fn no_retry_necessary_for_ok_result() {
let cfg = ConfigBag::base();
@ -221,4 +293,351 @@ mod tests {
.expect("method is infallible for this use");
assert_eq!(ShouldAttempt::No, actual);
}
#[derive(Debug)]
struct ServerError;
impl fmt::Display for ServerError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "OperationError")
}
}
impl std::error::Error for ServerError {}
impl ProvideErrorKind for ServerError {
fn retryable_error_kind(&self) -> Option<ErrorKind> {
Some(ErrorKind::ServerError)
}
fn code(&self) -> Option<&str> {
None
}
}
#[derive(Debug)]
struct PresetReasonRetryClassifier {
retry_reasons: Mutex<Vec<RetryReason>>,
}
#[cfg(feature = "test-util")]
impl PresetReasonRetryClassifier {
fn new(mut retry_reasons: Vec<RetryReason>) -> Self {
// Since we pop retry_reasons off the end of the list, reverse it to preserve the original order.
retry_reasons.reverse();
Self {
retry_reasons: Mutex::new(retry_reasons),
}
}
}
impl ClassifyRetry for PresetReasonRetryClassifier {
fn classify_retry(&self, ctx: &InterceptorContext) -> Option<RetryReason> {
if ctx.output_or_error().map(|it| it.is_ok()).unwrap_or(false) {
return None;
}
let mut retry_reasons = self.retry_reasons.lock().unwrap();
if retry_reasons.len() == 1 {
Some(retry_reasons.first().unwrap().clone())
} else {
retry_reasons.pop()
}
}
fn name(&self) -> &'static str {
"Always returns a preset retry reason"
}
}
#[cfg(feature = "test-util")]
fn setup_test(retry_reasons: Vec<RetryReason>) -> (ConfigBag, InterceptorContext) {
let mut cfg = ConfigBag::base();
cfg.interceptor_state().set_retry_classifiers(
RetryClassifiers::new()
.with_classifier(PresetReasonRetryClassifier::new(retry_reasons)),
);
let mut ctx = InterceptorContext::new(TypeErasedBox::doesnt_matter());
// This type doesn't matter b/c the classifier will just return whatever we tell it to.
ctx.set_output_or_error(Err(OrchestratorError::other("doesn't matter")));
(cfg, ctx)
}
#[cfg(feature = "test-util")]
#[test]
fn eventual_success() {
let (mut cfg, mut ctx) = setup_test(vec![RetryReason::Error(ErrorKind::ServerError)]);
let strategy = StandardRetryStrategy::default()
.with_base(|| 1.0)
.with_max_attempts(5);
cfg.interceptor_state().put(StandardTokenBucket::default());
let token_bucket = cfg.get::<StandardTokenBucket>().unwrap().clone();
cfg.interceptor_state().put(RequestAttempts::new(1));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(1));
assert_eq!(token_bucket.available_permits(), 495);
cfg.interceptor_state().put(RequestAttempts::new(2));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(2));
assert_eq!(token_bucket.available_permits(), 490);
ctx.set_output_or_error(Ok(TypeErasedBox::doesnt_matter()));
cfg.interceptor_state().put(RequestAttempts::new(3));
let no_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
assert_eq!(no_retry, ShouldAttempt::No);
assert_eq!(token_bucket.available_permits(), 495);
}
#[cfg(feature = "test-util")]
#[test]
fn no_more_attempts() {
let (mut cfg, ctx) = setup_test(vec![RetryReason::Error(ErrorKind::ServerError)]);
let strategy = StandardRetryStrategy::default()
.with_base(|| 1.0)
.with_max_attempts(3);
cfg.interceptor_state().put(StandardTokenBucket::default());
let token_bucket = cfg.get::<StandardTokenBucket>().unwrap().clone();
cfg.interceptor_state().put(RequestAttempts::new(1));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(1));
assert_eq!(token_bucket.available_permits(), 495);
cfg.interceptor_state().put(RequestAttempts::new(2));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(2));
assert_eq!(token_bucket.available_permits(), 490);
cfg.interceptor_state().put(RequestAttempts::new(3));
let no_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
assert_eq!(no_retry, ShouldAttempt::No);
assert_eq!(token_bucket.available_permits(), 490);
}
#[cfg(feature = "test-util")]
#[test]
fn no_quota() {
let (mut cfg, ctx) = setup_test(vec![RetryReason::Error(ErrorKind::ServerError)]);
let strategy = StandardRetryStrategy::default()
.with_base(|| 1.0)
.with_max_attempts(5);
cfg.interceptor_state().put(StandardTokenBucket::new(5));
let token_bucket = cfg.get::<StandardTokenBucket>().unwrap().clone();
cfg.interceptor_state().put(RequestAttempts::new(1));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(1));
assert_eq!(token_bucket.available_permits(), 0);
cfg.interceptor_state().put(RequestAttempts::new(2));
let no_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
assert_eq!(no_retry, ShouldAttempt::No);
assert_eq!(token_bucket.available_permits(), 0);
}
#[cfg(feature = "test-util")]
#[test]
fn quota_replenishes_on_success() {
let (mut cfg, mut ctx) = setup_test(vec![
RetryReason::Error(ErrorKind::TransientError),
RetryReason::Explicit(Duration::from_secs(1)),
]);
let strategy = StandardRetryStrategy::default()
.with_base(|| 1.0)
.with_max_attempts(5);
cfg.interceptor_state().put(StandardTokenBucket::new(100));
let token_bucket = cfg.get::<StandardTokenBucket>().unwrap().clone();
cfg.interceptor_state().put(RequestAttempts::new(1));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(1));
assert_eq!(token_bucket.available_permits(), 90);
cfg.interceptor_state().put(RequestAttempts::new(2));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(1));
assert_eq!(token_bucket.available_permits(), 90);
ctx.set_output_or_error(Ok(TypeErasedBox::doesnt_matter()));
cfg.interceptor_state().put(RequestAttempts::new(3));
let no_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
assert_eq!(no_retry, ShouldAttempt::No);
assert_eq!(token_bucket.available_permits(), 100);
}
#[cfg(feature = "test-util")]
#[test]
fn quota_replenishes_on_first_try_success() {
const PERMIT_COUNT: usize = 20;
let (mut cfg, mut ctx) = setup_test(vec![RetryReason::Error(ErrorKind::TransientError)]);
let strategy = StandardRetryStrategy::default()
.with_base(|| 1.0)
.with_max_attempts(usize::MAX);
cfg.interceptor_state()
.put(StandardTokenBucket::new(PERMIT_COUNT));
let token_bucket = cfg.get::<StandardTokenBucket>().unwrap().clone();
let mut attempt = 1;
// Drain all available permits with failed attempts
while token_bucket.available_permits() > 0 {
// Draining should complete in 2 attempts
if attempt > 2 {
panic!("This test should have completed by now (drain)");
}
cfg.interceptor_state().put(RequestAttempts::new(attempt));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
assert!(matches!(should_retry, ShouldAttempt::YesAfterDelay(_)));
attempt += 1;
}
// Forget the permit so that we can only refill by "success on first try".
let permit = strategy.retry_permit.lock().unwrap().take().unwrap();
permit.forget();
ctx.set_output_or_error(Ok(TypeErasedBox::doesnt_matter()));
// Replenish permits until we get back to `PERMIT_COUNT`
while token_bucket.available_permits() < PERMIT_COUNT {
if attempt > 23 {
panic!("This test should have completed by now (fillup)");
}
cfg.interceptor_state().put(RequestAttempts::new(attempt));
let no_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
assert_eq!(no_retry, ShouldAttempt::No);
attempt += 1;
}
assert_eq!(attempt, 23);
assert_eq!(token_bucket.available_permits(), PERMIT_COUNT);
}
#[cfg(feature = "test-util")]
#[test]
fn backoff_timing() {
let (mut cfg, ctx) = setup_test(vec![RetryReason::Error(ErrorKind::ServerError)]);
let strategy = StandardRetryStrategy::default()
.with_base(|| 1.0)
.with_max_attempts(5);
cfg.interceptor_state().put(StandardTokenBucket::default());
let token_bucket = cfg.get::<StandardTokenBucket>().unwrap().clone();
cfg.interceptor_state().put(RequestAttempts::new(1));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(1));
assert_eq!(token_bucket.available_permits(), 495);
cfg.interceptor_state().put(RequestAttempts::new(2));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(2));
assert_eq!(token_bucket.available_permits(), 490);
cfg.interceptor_state().put(RequestAttempts::new(3));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(4));
assert_eq!(token_bucket.available_permits(), 485);
cfg.interceptor_state().put(RequestAttempts::new(4));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(8));
assert_eq!(token_bucket.available_permits(), 480);
cfg.interceptor_state().put(RequestAttempts::new(5));
let no_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
assert_eq!(no_retry, ShouldAttempt::No);
assert_eq!(token_bucket.available_permits(), 480);
}
#[cfg(feature = "test-util")]
#[test]
fn max_backoff_time() {
let (mut cfg, ctx) = setup_test(vec![RetryReason::Error(ErrorKind::ServerError)]);
let strategy = StandardRetryStrategy::default()
.with_base(|| 1.0)
.with_max_attempts(5)
.with_initial_backoff(Duration::from_secs(1))
.with_max_backoff(Duration::from_secs(3));
cfg.interceptor_state().put(StandardTokenBucket::default());
let token_bucket = cfg.get::<StandardTokenBucket>().unwrap().clone();
cfg.interceptor_state().put(RequestAttempts::new(1));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(1));
assert_eq!(token_bucket.available_permits(), 495);
cfg.interceptor_state().put(RequestAttempts::new(2));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(2));
assert_eq!(token_bucket.available_permits(), 490);
cfg.interceptor_state().put(RequestAttempts::new(3));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(3));
assert_eq!(token_bucket.available_permits(), 485);
cfg.interceptor_state().put(RequestAttempts::new(4));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(3));
assert_eq!(token_bucket.available_permits(), 480);
cfg.interceptor_state().put(RequestAttempts::new(5));
let no_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
assert_eq!(no_retry, ShouldAttempt::No);
assert_eq!(token_bucket.available_permits(), 480);
}
#[test]
fn calculate_exponential_backoff_where_initial_backoff_is_one() {
let initial_backoff = 1.0;
for (attempt, expected_backoff) in [initial_backoff, 2.0, 4.0].into_iter().enumerate() {
let actual_backoff =
calculate_exponential_backoff(1.0, initial_backoff, attempt as u32);
assert_eq!(expected_backoff, actual_backoff);
}
}
#[test]
fn calculate_exponential_backoff_where_initial_backoff_is_greater_than_one() {
let initial_backoff = 3.0;
for (attempt, expected_backoff) in [initial_backoff, 6.0, 12.0].into_iter().enumerate() {
let actual_backoff =
calculate_exponential_backoff(1.0, initial_backoff, attempt as u32);
assert_eq!(expected_backoff, actual_backoff);
}
}
#[test]
fn calculate_exponential_backoff_where_initial_backoff_is_less_than_one() {
let initial_backoff = 0.03;
for (attempt, expected_backoff) in [initial_backoff, 0.06, 0.12].into_iter().enumerate() {
let actual_backoff =
calculate_exponential_backoff(1.0, initial_backoff, attempt as u32);
assert_eq!(expected_backoff, actual_backoff);
}
}
}
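
For reference, the three backoff tests pin down the formula exactly: with `base = 1.0`, the delay doubles per retry starting from `initial_backoff`, and `should_attempt_retry` clamps the result to `max_backoff`. A self-contained sketch of what `calculate_exponential_backoff` must compute to satisfy them:

```rust
fn calculate_exponential_backoff(base: f64, initial_backoff: f64, retry_attempts: u32) -> f64 {
    // backoff = base * initial_backoff * 2^retry_attempts (in seconds);
    // the caller clamps the result to `max_backoff`.
    base * initial_backoff * 2_f64.powi(retry_attempts as i32)
}

fn main() {
    // Matches the tests: attempts 0, 1, 2 with base 1.0 yield 1s, 2s, 4s.
    assert_eq!(calculate_exponential_backoff(1.0, 1.0, 0), 1.0);
    assert_eq!(calculate_exponential_backoff(1.0, 1.0, 2), 4.0);
}
```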

View File

@ -5,3 +5,5 @@
#[cfg(feature = "anonymous-auth")]
pub mod anonymous_auth;
pub mod standard_token_bucket;

View File

@ -0,0 +1,100 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin;
use aws_smithy_types::config_bag::{FrozenLayer, Layer, Storable, StoreReplace};
use aws_smithy_types::retry::ErrorKind;
use std::sync::Arc;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tracing::trace;
/// A [RuntimePlugin] to provide a standard token bucket, usable by the
/// [`StandardRetryStrategy`](crate::client::retries::strategy::standard::StandardRetryStrategy).
#[non_exhaustive]
#[derive(Debug, Default)]
pub struct StandardTokenBucketRuntimePlugin {
token_bucket: StandardTokenBucket,
}
impl StandardTokenBucketRuntimePlugin {
pub fn new(initial_tokens: usize) -> Self {
Self {
token_bucket: StandardTokenBucket::new(initial_tokens),
}
}
}
impl RuntimePlugin for StandardTokenBucketRuntimePlugin {
fn config(&self) -> Option<FrozenLayer> {
let mut cfg = Layer::new("standard token bucket");
cfg.store_put(self.token_bucket.clone());
Some(cfg.freeze())
}
}
const DEFAULT_CAPACITY: usize = 500;
const RETRY_COST: u32 = 5;
const RETRY_TIMEOUT_COST: u32 = RETRY_COST * 2;
const PERMIT_REGENERATION_AMOUNT: usize = 1;
#[derive(Clone, Debug)]
pub(crate) struct StandardTokenBucket {
semaphore: Arc<Semaphore>,
max_permits: usize,
timeout_retry_cost: u32,
retry_cost: u32,
}
impl Storable for StandardTokenBucket {
type Storer = StoreReplace<Self>;
}
impl Default for StandardTokenBucket {
fn default() -> Self {
Self {
semaphore: Arc::new(Semaphore::new(DEFAULT_CAPACITY)),
max_permits: DEFAULT_CAPACITY,
timeout_retry_cost: RETRY_TIMEOUT_COST,
retry_cost: RETRY_COST,
}
}
}
impl StandardTokenBucket {
pub(crate) fn new(initial_quota: usize) -> Self {
Self {
semaphore: Arc::new(Semaphore::new(initial_quota)),
max_permits: initial_quota,
retry_cost: RETRY_COST,
timeout_retry_cost: RETRY_TIMEOUT_COST,
}
}
pub(crate) fn acquire(&self, err: &ErrorKind) -> Option<OwnedSemaphorePermit> {
let retry_cost = if err == &ErrorKind::TransientError {
self.timeout_retry_cost
} else {
self.retry_cost
};
self.semaphore
.clone()
.try_acquire_many_owned(retry_cost)
.ok()
}
pub(crate) fn regenerate_a_token(&self) {
if self.semaphore.available_permits() < (self.max_permits) {
trace!("adding {PERMIT_REGENERATION_AMOUNT} back into the bucket");
self.semaphore.add_permits(PERMIT_REGENERATION_AMOUNT)
}
}
#[cfg(all(test, feature = "test-util"))]
pub(crate) fn available_permits(&self) -> usize {
self.semaphore.available_permits()
}
}
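
The permit lifecycle this type implies: `acquire` deducts tokens for a retry (double cost for transient errors), dropping the returned permit refunds them, `forget`ting it removes them for good, and `regenerate_a_token` repays one token on a first-try success. A sketch of that cycle (this is a `pub(crate)` API, shown for illustration only):

```rust
use aws_smithy_types::retry::ErrorKind;

let bucket = StandardTokenBucket::default(); // starts with 500 tokens
if let Some(permit) = bucket.acquire(&ErrorKind::ServerError) {
    // 5 tokens are held while the retry is in flight (10 for TransientError).
    // Dropping the permit returns the tokens to the bucket...
    drop(permit);
    // ...whereas `permit.forget()` would remove them permanently.
}
// On a first-try success, one token is added back (never exceeding `max_permits`).
bucket.regenerate_a_token();
```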

View File

@ -143,6 +143,7 @@ pub struct RetryConfigBuilder {
mode: Option<RetryMode>,
max_attempts: Option<u32>,
initial_backoff: Option<Duration>,
max_backoff: Option<Duration>,
reconnect_mode: Option<ReconnectMode>,
}
@ -212,6 +213,18 @@ impl RetryConfigBuilder {
self
}
/// Set the max_backoff duration. This duration should be non-zero.
pub fn set_max_backoff(&mut self, max_backoff: Option<Duration>) -> &mut Self {
self.max_backoff = max_backoff;
self
}
/// Set the max_backoff duration. This duration should be non-zero.
pub fn max_backoff(mut self, max_backoff: Duration) -> Self {
self.set_max_backoff(Some(max_backoff));
self
}
/// Merge two builders together. Values from `other` will only be used as a fallback for values
/// from `self`. Useful for merging configs from different sources together when you want to
/// handle "precedence" per value instead of at the config level
@ -233,6 +246,7 @@ impl RetryConfigBuilder {
mode: self.mode.or(other.mode),
max_attempts: self.max_attempts.or(other.max_attempts),
initial_backoff: self.initial_backoff.or(other.initial_backoff),
max_backoff: self.max_backoff.or(other.max_backoff),
reconnect_mode: self.reconnect_mode.or(other.reconnect_mode),
}
}
@ -248,6 +262,7 @@ impl RetryConfigBuilder {
reconnect_mode: self
.reconnect_mode
.unwrap_or(ReconnectMode::ReconnectOnTransientError),
max_backoff: self.max_backoff.unwrap_or_else(|| Duration::from_secs(20)),
}
}
}
@ -259,6 +274,7 @@ pub struct RetryConfig {
mode: RetryMode,
max_attempts: u32,
initial_backoff: Duration,
max_backoff: Duration,
reconnect_mode: ReconnectMode,
}
@ -286,6 +302,7 @@ impl RetryConfig {
max_attempts: 3,
initial_backoff: Duration::from_secs(1),
reconnect_mode: ReconnectMode::ReconnectOnTransientError,
max_backoff: Duration::from_secs(20),
}
}
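
A quick sketch of setting the new knob through the builder (the surrounding builder methods predate this change):

```rust
use aws_smithy_types::retry::{RetryConfigBuilder, RetryMode};
use std::time::Duration;

let retry_config = RetryConfigBuilder::new()
    .mode(RetryMode::Standard)
    .max_attempts(5)
    .max_backoff(Duration::from_secs(10)) // cap exponential backoff at 10s
    .build();
```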