mirror of https://github.com/smithy-lang/smithy-rs
Update standard orchestrator retries with token bucket and more tests (#2764)
## Motivation and Context <!--- Why is this change required? What problem does it solve? --> <!--- If it fixes an open issue, please link to the issue here --> addresses #2743 ## Description <!--- Describe your changes in detail --> - add more standard retry tests - add optional standard retries token bucket ## Testing <!--- Please describe in detail how you tested your changes --> <!--- Include details of your testing environment, and the tests you ran to --> <!--- see how your change affects other areas of the code, etc. --> tests are included ---- _By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice._ --------- Co-authored-by: John DiSanti <jdisanti@amazon.com>
This commit is contained in:
parent
5473192d3f
commit
312d190535
|
@ -465,7 +465,6 @@ async fn build_provider_chain(
|
|||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
|
||||
use crate::profile::credentials::Builder;
|
||||
use crate::test_case::TestEnvironment;
|
||||
|
||||
|
|
|
@ -146,20 +146,6 @@ open class OperationGenerator(
|
|||
if (codegenContext.smithyRuntimeMode.generateOrchestrator) {
|
||||
rustTemplate(
|
||||
"""
|
||||
pub(crate) fn register_runtime_plugins(
|
||||
runtime_plugins: #{RuntimePlugins},
|
||||
handle: #{Arc}<crate::client::Handle>,
|
||||
config_override: #{Option}<crate::config::Builder>,
|
||||
) -> #{RuntimePlugins} {
|
||||
#{register_default_runtime_plugins}(
|
||||
runtime_plugins,
|
||||
#{Box}::new(Self::new()) as _,
|
||||
handle,
|
||||
config_override
|
||||
)
|
||||
#{additional_runtime_plugins}
|
||||
}
|
||||
|
||||
pub(crate) async fn orchestrate(
|
||||
runtime_plugins: &#{RuntimePlugins},
|
||||
input: #{Input},
|
||||
|
@ -186,6 +172,20 @@ open class OperationGenerator(
|
|||
let input = #{TypedBox}::new(input).erase();
|
||||
#{invoke_with_stop_point}(input, runtime_plugins, stop_point).await
|
||||
}
|
||||
|
||||
pub(crate) fn register_runtime_plugins(
|
||||
runtime_plugins: #{RuntimePlugins},
|
||||
handle: #{Arc}<crate::client::Handle>,
|
||||
config_override: #{Option}<crate::config::Builder>,
|
||||
) -> #{RuntimePlugins} {
|
||||
#{register_default_runtime_plugins}(
|
||||
runtime_plugins,
|
||||
#{Box}::new(Self::new()) as _,
|
||||
handle,
|
||||
config_override
|
||||
)
|
||||
#{additional_runtime_plugins}
|
||||
}
|
||||
""",
|
||||
*codegenScope,
|
||||
"Error" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::interceptors::context::Error"),
|
||||
|
|
|
@ -105,6 +105,8 @@ class ServiceRuntimePluginGenerator(
|
|||
"StaticAuthOptionResolver" to runtimeApi.resolve("client::auth::option_resolver::StaticAuthOptionResolver"),
|
||||
"default_connector" to client.resolve("conns::default_connector"),
|
||||
"require_connector" to client.resolve("conns::require_connector"),
|
||||
"TimeoutConfig" to smithyTypes.resolve("timeout::TimeoutConfig"),
|
||||
"RetryConfig" to smithyTypes.resolve("retry::RetryConfig"),
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -142,20 +144,17 @@ class ServiceRuntimePluginGenerator(
|
|||
self.handle.conf.endpoint_resolver());
|
||||
cfg.set_endpoint_resolver(endpoint_resolver);
|
||||
|
||||
// TODO(enableNewSmithyRuntime): Use the `store_append` method of ConfigBag to insert classifiers
|
||||
let retry_classifiers = #{RetryClassifiers}::new()
|
||||
#{retry_classifier_customizations};
|
||||
cfg.set_retry_classifiers(retry_classifiers);
|
||||
// TODO(enableNewSmithyRuntime): Make it possible to set retry classifiers at the service level.
|
||||
// Retry classifiers can also be set at the operation level and those should be added to the
|
||||
// list of classifiers defined here, rather than replacing them.
|
||||
|
||||
let sleep_impl = self.handle.conf.sleep_impl();
|
||||
let timeout_config = self.handle.conf.timeout_config();
|
||||
let retry_config = self.handle.conf.retry_config();
|
||||
let timeout_config = self.handle.conf.timeout_config().cloned().unwrap_or_else(|| #{TimeoutConfig}::disabled());
|
||||
let retry_config = self.handle.conf.retry_config().cloned().unwrap_or_else(|| #{RetryConfig}::disabled());
|
||||
|
||||
if let Some(retry_config) = retry_config {
|
||||
cfg.set_retry_strategy(#{StandardRetryStrategy}::new(retry_config));
|
||||
}
|
||||
cfg.set_retry_strategy(#{StandardRetryStrategy}::new(&retry_config));
|
||||
|
||||
let connector_settings = timeout_config.map(#{ConnectorSettings}::from_timeout_config).unwrap_or_default();
|
||||
let connector_settings = #{ConnectorSettings}::from_timeout_config(&timeout_config);
|
||||
if let Some(connection) = self.handle.conf.http_connector()
|
||||
.and_then(|c| c.connector(&connector_settings, sleep_impl.clone()))
|
||||
.or_else(|| #{default_connector}(&connector_settings, sleep_impl)) {
|
||||
|
@ -180,9 +179,6 @@ class ServiceRuntimePluginGenerator(
|
|||
"http_auth_scheme_customizations" to writable {
|
||||
writeCustomizations(customizations, ServiceRuntimePluginSection.HttpAuthScheme("cfg"))
|
||||
},
|
||||
"retry_classifier_customizations" to writable {
|
||||
writeCustomizations(customizations, ServiceRuntimePluginSection.RetryClassifier("cfg"))
|
||||
},
|
||||
"additional_config" to writable {
|
||||
writeCustomizations(customizations, ServiceRuntimePluginSection.AdditionalConfig("cfg"))
|
||||
},
|
||||
|
|
|
@ -18,6 +18,12 @@ pub enum OrchestratorError<E> {
|
|||
Interceptor { err: InterceptorError },
|
||||
/// An error returned by a service.
|
||||
Operation { err: E },
|
||||
/// An error that occurs when a request times out.
|
||||
Timeout { err: BoxError },
|
||||
/// An error that occurs when request dispatch fails.
|
||||
Connector { err: ConnectorError },
|
||||
/// An error that occurs when a response can't be deserialized.
|
||||
Response { err: BoxError },
|
||||
/// A general orchestrator error.
|
||||
Other { err: BoxError },
|
||||
}
|
||||
|
@ -34,11 +40,26 @@ impl<E: Debug> OrchestratorError<E> {
|
|||
Self::Operation { err }
|
||||
}
|
||||
|
||||
/// Create a new `OrchestratorError` from an [`InterceptorError`].
|
||||
/// Create a new `OrchestratorError::Interceptor` from an [`InterceptorError`].
|
||||
pub fn interceptor(err: InterceptorError) -> Self {
|
||||
Self::Interceptor { err }
|
||||
}
|
||||
|
||||
/// Create a new `OrchestratorError::Timeout` from a [`BoxError`].
|
||||
pub fn timeout(err: BoxError) -> Self {
|
||||
Self::Timeout { err }
|
||||
}
|
||||
|
||||
/// Create a new `OrchestratorError::Response` from a [`BoxError`].
|
||||
pub fn response(err: BoxError) -> Self {
|
||||
Self::Response { err }
|
||||
}
|
||||
|
||||
/// Create a new `OrchestratorError::Connector` from a [`ConnectorError`].
|
||||
pub fn connector(err: ConnectorError) -> Self {
|
||||
Self::Connector { err }
|
||||
}
|
||||
|
||||
/// Convert the `OrchestratorError` into `Some` operation specific error if it is one. Otherwise,
|
||||
/// return `None`.
|
||||
pub fn as_operation_error(&self) -> Option<&E> {
|
||||
|
@ -72,6 +93,9 @@ impl<E: Debug> OrchestratorError<E> {
|
|||
debug_assert!(phase.is_after_deserialization(), "operation errors are a result of successfully receiving and parsing a response from the server. Therefore, we must be in the 'After Deserialization' phase.");
|
||||
SdkError::service_error(err, response.expect("phase has a response"))
|
||||
}
|
||||
Self::Connector { err } => SdkError::dispatch_failure(err),
|
||||
Self::Timeout { err } => SdkError::timeout_error(err),
|
||||
Self::Response { err } => SdkError::response_error(err, response.unwrap()),
|
||||
Self::Other { err } => {
|
||||
use Phase::*;
|
||||
match phase {
|
||||
|
@ -111,15 +135,6 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<E> From<BoxError> for OrchestratorError<E>
|
||||
where
|
||||
E: Debug + std::error::Error + 'static,
|
||||
{
|
||||
fn from(err: BoxError) -> Self {
|
||||
Self::other(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TypeErasedError> for OrchestratorError<TypeErasedError> {
|
||||
fn from(err: TypeErasedError) -> Self {
|
||||
Self::operation(err)
|
||||
|
|
|
@ -20,6 +20,16 @@ pub enum ShouldAttempt {
|
|||
YesAfterDelay(Duration),
|
||||
}
|
||||
|
||||
#[cfg(feature = "test-util")]
impl ShouldAttempt {
    /// Test helper that extracts the delay carried by [`ShouldAttempt::YesAfterDelay`].
    ///
    /// # Panics
    ///
    /// Panics when `self` is any variant other than `YesAfterDelay`.
    pub fn expect_delay(self) -> Duration {
        if let ShouldAttempt::YesAfterDelay(delay) = self {
            return delay;
        }

        panic!("Expected this to be the `YesAfterDelay` variant but it was the `{self:?}` variant instead")
    }
}
|
||||
|
||||
pub trait RetryStrategy: Send + Sync + Debug {
|
||||
fn should_attempt_initial_request(&self, cfg: &ConfigBag) -> Result<ShouldAttempt, BoxError>;
|
||||
|
||||
|
@ -31,7 +41,7 @@ pub trait RetryStrategy: Send + Sync + Debug {
|
|||
}
|
||||
|
||||
#[non_exhaustive]
|
||||
#[derive(Eq, PartialEq, Debug)]
|
||||
#[derive(Clone, Eq, PartialEq, Debug)]
|
||||
pub enum RetryReason {
|
||||
Error(ErrorKind),
|
||||
Explicit(Duration),
|
||||
|
@ -72,10 +82,10 @@ impl RetryClassifiers {
|
|||
}
|
||||
|
||||
impl ClassifyRetry for RetryClassifiers {
|
||||
fn classify_retry(&self, error: &InterceptorContext) -> Option<RetryReason> {
|
||||
fn classify_retry(&self, ctx: &InterceptorContext) -> Option<RetryReason> {
|
||||
// return the first non-None result
|
||||
self.inner.iter().find_map(|cr| {
|
||||
let maybe_reason = cr.classify_retry(error);
|
||||
let maybe_reason = cr.classify_retry(ctx);
|
||||
|
||||
match maybe_reason.as_ref() {
|
||||
Some(reason) => trace!(
|
||||
|
|
|
@ -1,13 +0,0 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
//! Code for rate-limiting smithy clients.
|
||||
|
||||
pub mod error;
|
||||
pub mod token;
|
||||
pub mod token_bucket;
|
||||
|
||||
pub use token::Token;
|
||||
pub use token_bucket::TokenBucket;
|
|
@ -1,50 +0,0 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
//! Errors related to rate limiting
|
||||
|
||||
use std::fmt;
|
||||
|
||||
/// The error type returned by token bucket operations.
#[derive(Debug)]
pub struct RateLimitingError {
    kind: ErrorKind,
}

impl RateLimitingError {
    /// Create the error reported when a token is requested but the bucket is empty.
    pub fn no_tokens() -> Self {
        let kind = ErrorKind::NoTokens;
        Self { kind }
    }

    /// Create the error reported when a bug in the code is detected. Please report bugs you
    /// encounter.
    pub fn bug(s: impl ToString) -> Self {
        let kind = ErrorKind::Bug(s.to_string());
        Self { kind }
    }
}

/// The private classification behind a [`RateLimitingError`].
#[derive(Debug)]
enum ErrorKind {
    /// The bucket had no tokens left when one was requested.
    NoTokens,
    /// An error that should never happen; the message describes the bug to report.
    Bug(String),
}

impl fmt::Display for RateLimitingError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match &self.kind {
            ErrorKind::NoTokens => f.write_str("No more tokens are left in the bucket."),
            ErrorKind::Bug(msg) => {
                write!(f, "you've encountered a bug that needs reporting: {}", msg)
            }
        }
    }
}

impl std::error::Error for RateLimitingError {}
|
|
@ -1,65 +0,0 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
//! Types and traits related to token buckets. Token buckets are used to limit the amount of
|
||||
//! requests a client sends in order to avoid getting throttled. Token buckets can also act as a
|
||||
//! form of concurrency control if a token is required to send a new request (as opposed to retry
|
||||
//! requests only).
|
||||
|
||||
use tokio::sync::OwnedSemaphorePermit;
|
||||
|
||||
/// A token handed out by a [`TokenBucket`](super::TokenBucket).
///
/// Both methods take the token by value, so a token can be released or forgotten at most once.
pub trait Token {
    /// Give this token back to the bucket. Call this when the related request succeeds.
    fn release(self);

    /// Permanently discard this token so that it never returns to the bucket. Call this when the
    /// related request fails.
    fn forget(self);
}
|
||||
|
||||
/// The token type of [`Standard`].
|
||||
#[derive(Debug)]
|
||||
pub struct Standard {
|
||||
permit: Option<OwnedSemaphorePermit>,
|
||||
}
|
||||
|
||||
impl Standard {
|
||||
pub(crate) fn new(permit: OwnedSemaphorePermit) -> Self {
|
||||
Self {
|
||||
permit: Some(permit),
|
||||
}
|
||||
}
|
||||
|
||||
// Return an "empty" token for times when you need to return a token but there's no "cost"
|
||||
// associated with an action.
|
||||
pub(crate) fn empty() -> Self {
|
||||
Self { permit: None }
|
||||
}
|
||||
}
|
||||
|
||||
impl Token for Standard {
|
||||
fn release(self) {
|
||||
drop(self.permit)
|
||||
}
|
||||
|
||||
fn forget(self) {
|
||||
if let Some(permit) = self.permit {
|
||||
permit.forget()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::Standard as StandardToken;
    use crate::client::retries::rate_limiting::token_bucket::Standard as StandardBucket;
    use crate::client::retries::rate_limiting::TokenBucket;

    /// Compile-time check: the standard bucket must be usable as a `dyn TokenBucket`.
    #[test]
    fn token_bucket_trait_is_dyn_safe() {
        let bucket = StandardBucket::builder().build();
        let _boxed: Box<dyn TokenBucket<Token = StandardToken>> = Box::new(bucket);
    }
}
|
|
@ -1,235 +0,0 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
//! A token bucket intended for use with the standard smithy client retry policy.
|
||||
|
||||
use super::error::RateLimitingError;
|
||||
use super::token;
|
||||
use super::Token;
|
||||
use aws_smithy_types::retry::{ErrorKind, RetryKind};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Semaphore;
|
||||
use tokio::sync::TryAcquireError;
|
||||
|
||||
/// The default number of tokens to start with
|
||||
const STANDARD_INITIAL_RETRY_TOKENS: usize = 500;
|
||||
/// The amount of tokens to remove from the bucket when a timeout error occurs
|
||||
const STANDARD_TIMEOUT_ERROR_RETRY_COST: u32 = 10;
|
||||
/// The amount of tokens to remove from the bucket when a retryable error that isn't timeout-related occurs
|
||||
const STANDARD_RETRYABLE_ERROR_RETRY_COST: u32 = 5;
|
||||
|
||||
/// This trait is implemented by types that act as token buckets. Token buckets are used to regulate
|
||||
/// the amount of requests sent by clients. Different token buckets may apply different strategies
|
||||
/// to manage the number of tokens in a bucket.
|
||||
///
|
||||
/// related: [`Token`], [`RateLimitingError`]
|
||||
pub trait TokenBucket {
|
||||
/// The type of tokens this bucket dispenses.
|
||||
type Token: Token;
|
||||
|
||||
/// Attempt to acquire a token from the bucket. This will fail if the bucket has no more tokens.
|
||||
fn try_acquire(
|
||||
&self,
|
||||
previous_response_kind: Option<RetryKind>,
|
||||
) -> Result<Self::Token, RateLimitingError>;
|
||||
|
||||
/// Get the number of available tokens in the bucket.
|
||||
fn available(&self) -> usize;
|
||||
|
||||
/// Refill the bucket with the given number of tokens.
|
||||
fn refill(&self, tokens: usize);
|
||||
}
|
||||
|
||||
/// A token bucket implementation that uses a `tokio::sync::Semaphore` to track the number of tokens.
|
||||
///
|
||||
/// - Whenever a request succeeds on the first try, `<success_on_first_try_refill_amount>` token(s)
|
||||
/// are added back to the bucket.
|
||||
/// - When a request fails with a timeout error, `<timeout_error_cost>` token(s)
|
||||
/// are removed from the bucket.
|
||||
/// - When a request fails with a retryable error, `<retryable_error_cost>` token(s)
|
||||
/// are removed from the bucket.
|
||||
///
|
||||
/// The number of tokens in the bucket will always be >= `0` and <= `<max_tokens>`.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Standard {
|
||||
inner: Arc<Semaphore>,
|
||||
max_tokens: usize,
|
||||
timeout_error_cost: u32,
|
||||
retryable_error_cost: u32,
|
||||
}
|
||||
|
||||
impl Standard {
|
||||
/// Create a new `TokenBucket` using builder methods.
|
||||
pub fn builder() -> Builder {
|
||||
Builder::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// A builder for `TokenBucket`s.
|
||||
#[derive(Default, Debug)]
|
||||
pub struct Builder {
|
||||
starting_tokens: Option<usize>,
|
||||
max_tokens: Option<usize>,
|
||||
timeout_error_cost: Option<u32>,
|
||||
retryable_error_cost: Option<u32>,
|
||||
}
|
||||
|
||||
impl Builder {
|
||||
/// The number of tokens the bucket will start with. Defaults to 500.
|
||||
pub fn starting_tokens(mut self, starting_tokens: usize) -> Self {
|
||||
self.starting_tokens = Some(starting_tokens);
|
||||
self
|
||||
}
|
||||
|
||||
/// The maximum number of tokens that the bucket can hold.
|
||||
/// Defaults to the value of `starting_tokens`.
|
||||
pub fn max_tokens(mut self, max_tokens: usize) -> Self {
|
||||
self.max_tokens = Some(max_tokens);
|
||||
self
|
||||
}
|
||||
|
||||
/// How many tokens to remove from the bucket when a request fails due to a timeout error.
|
||||
/// Defaults to 10.
|
||||
pub fn timeout_error_cost(mut self, timeout_error_cost: u32) -> Self {
|
||||
self.timeout_error_cost = Some(timeout_error_cost);
|
||||
self
|
||||
}
|
||||
|
||||
/// How many tokens to remove from the bucket when a request fails due to a retryable error that
|
||||
/// isn't timeout-related. Defaults to 5.
|
||||
pub fn retryable_error_cost(mut self, retryable_error_cost: u32) -> Self {
|
||||
self.retryable_error_cost = Some(retryable_error_cost);
|
||||
self
|
||||
}
|
||||
|
||||
/// Build this builder. Unset fields will be set to their default values.
|
||||
pub fn build(self) -> Standard {
|
||||
let starting_tokens = self
|
||||
.starting_tokens
|
||||
.unwrap_or(STANDARD_INITIAL_RETRY_TOKENS);
|
||||
let max_tokens = self.max_tokens.unwrap_or(starting_tokens);
|
||||
let timeout_error_cost = self
|
||||
.timeout_error_cost
|
||||
.unwrap_or(STANDARD_TIMEOUT_ERROR_RETRY_COST);
|
||||
let retryable_error_cost = self
|
||||
.retryable_error_cost
|
||||
.unwrap_or(STANDARD_RETRYABLE_ERROR_RETRY_COST);
|
||||
|
||||
Standard {
|
||||
inner: Arc::new(Semaphore::new(starting_tokens)),
|
||||
max_tokens,
|
||||
timeout_error_cost,
|
||||
retryable_error_cost,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenBucket for Standard {
|
||||
type Token = token::Standard;
|
||||
|
||||
fn try_acquire(
|
||||
&self,
|
||||
previous_response_kind: Option<RetryKind>,
|
||||
) -> Result<Self::Token, RateLimitingError> {
|
||||
let number_of_tokens_to_acquire = match previous_response_kind {
|
||||
None => {
|
||||
// Return an empty token because the quota layer lifecycle expects a for each
|
||||
// request even though the standard token bucket only requires tokens for retry
|
||||
// attempts.
|
||||
return Ok(token::Standard::empty());
|
||||
}
|
||||
|
||||
Some(retry_kind) => match retry_kind {
|
||||
RetryKind::Unnecessary => {
|
||||
unreachable!("BUG: asked for a token to retry a successful request")
|
||||
}
|
||||
RetryKind::UnretryableFailure => {
|
||||
unreachable!("BUG: asked for a token to retry an un-retryable request")
|
||||
}
|
||||
RetryKind::Explicit(_) => self.retryable_error_cost,
|
||||
RetryKind::Error(error_kind) => match error_kind {
|
||||
ErrorKind::ThrottlingError | ErrorKind::TransientError => {
|
||||
self.timeout_error_cost
|
||||
}
|
||||
ErrorKind::ServerError => self.retryable_error_cost,
|
||||
ErrorKind::ClientError => unreachable!(
|
||||
"BUG: asked for a token to retry a request that failed due to user error"
|
||||
),
|
||||
_ => unreachable!(
|
||||
"A new variant '{:?}' was added to ErrorKind, please handle it",
|
||||
error_kind
|
||||
),
|
||||
},
|
||||
_ => unreachable!(
|
||||
"A new variant '{:?}' was added to RetryKind, please handle it",
|
||||
retry_kind
|
||||
),
|
||||
},
|
||||
};
|
||||
|
||||
match self
|
||||
.inner
|
||||
.clone()
|
||||
.try_acquire_many_owned(number_of_tokens_to_acquire)
|
||||
{
|
||||
Ok(permit) => Ok(token::Standard::new(permit)),
|
||||
Err(TryAcquireError::NoPermits) => Err(RateLimitingError::no_tokens()),
|
||||
Err(other) => Err(RateLimitingError::bug(other.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
fn available(&self) -> usize {
|
||||
self.inner.available_permits()
|
||||
}
|
||||
|
||||
fn refill(&self, tokens: usize) {
|
||||
// Ensure the bucket doesn't overflow by limiting the amount of tokens to add, if necessary.
|
||||
let amount_to_add = (self.available() + tokens).min(self.max_tokens) - self.available();
|
||||
if amount_to_add > 0 {
|
||||
self.inner.add_permits(amount_to_add)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod test {
    use super::{Token, TokenBucket};
    use super::{
        STANDARD_INITIAL_RETRY_TOKENS, STANDARD_RETRYABLE_ERROR_RETRY_COST,
        STANDARD_TIMEOUT_ERROR_RETRY_COST,
    };
    use aws_smithy_types::retry::{ErrorKind, RetryKind};

    /// Walks a default bucket through acquire/release, acquire/forget, and refill.
    #[test]
    fn bucket_works() {
        let full = STANDARD_INITIAL_RETRY_TOKENS;
        let after_server_error = full - STANDARD_RETRYABLE_ERROR_RETRY_COST as usize;
        let after_transient_error = full - STANDARD_TIMEOUT_ERROR_RETRY_COST as usize;

        let bucket = super::Standard::builder().build();
        assert_eq!(bucket.available(), full);

        // A server error costs the "retryable error" amount; releasing hands the tokens back.
        let server_error = Some(RetryKind::Error(ErrorKind::ServerError));
        let token = bucket.try_acquire(server_error).unwrap();
        assert_eq!(bucket.available(), after_server_error);
        token.release();

        // A transient error costs the "timeout error" amount; forgetting keeps the tokens gone.
        let transient_error = Some(RetryKind::Error(ErrorKind::TransientError));
        let token = bucket.try_acquire(transient_error).unwrap();
        assert_eq!(bucket.available(), after_transient_error);
        token.forget();
        assert_eq!(bucket.available(), after_transient_error);

        // Refilling by the forgotten amount brings the bucket back to its starting level.
        bucket.refill(STANDARD_TIMEOUT_ERROR_RETRY_COST as usize);
        assert_eq!(bucket.available(), full);
    }
}
|
|
@ -139,7 +139,7 @@ async fn try_op(
|
|||
{
|
||||
let request_serializer = cfg.request_serializer();
|
||||
let input = ctx.take_input().expect("input set at this point");
|
||||
let request = halt_on_err!([ctx] => request_serializer.serialize_input(input, cfg));
|
||||
let request = halt_on_err!([ctx] => request_serializer.serialize_input(input, cfg).map_err(OrchestratorError::other));
|
||||
ctx.set_request(request);
|
||||
}
|
||||
|
||||
|
@ -171,10 +171,10 @@ async fn try_op(
|
|||
// No, this request shouldn't be sent
|
||||
Ok(ShouldAttempt::No) => {
|
||||
let err: BoxError = "the retry strategy indicates that an initial request shouldn't be made, but it didn't specify why".into();
|
||||
halt!([ctx] => err);
|
||||
halt!([ctx] => OrchestratorError::other(err));
|
||||
}
|
||||
// No, we shouldn't make a request because...
|
||||
Err(err) => halt!([ctx] => err),
|
||||
Err(err) => halt!([ctx] => OrchestratorError::other(err)),
|
||||
Ok(ShouldAttempt::YesAfterDelay(_)) => {
|
||||
unreachable!("Delaying the initial request is currently unsupported. If this feature is important to you, please file an issue in GitHub.")
|
||||
}
|
||||
|
@ -183,7 +183,7 @@ async fn try_op(
|
|||
// Save a request checkpoint before we make the request. This will allow us to "rewind"
|
||||
// the request in the case of retry attempts.
|
||||
ctx.save_checkpoint();
|
||||
for i in 0usize.. {
|
||||
for i in 1usize.. {
|
||||
debug!("beginning attempt #{i}");
|
||||
// Break from the loop if we can't rewind the request's state. This will always succeed the
|
||||
// first time, but will fail on subsequent iterations if the request body wasn't retryable.
|
||||
|
@ -201,19 +201,21 @@ async fn try_op(
|
|||
}
|
||||
.maybe_timeout_with_config(attempt_timeout_config)
|
||||
.await
|
||||
.map_err(OrchestratorError::other);
|
||||
.map_err(|err| OrchestratorError::timeout(err.into_source().unwrap()));
|
||||
|
||||
// We continue when encountering a timeout error. The retry classifier will decide what to do with it.
|
||||
continue_on_err!([ctx] => maybe_timeout);
|
||||
|
||||
let retry_strategy = cfg.retry_strategy();
|
||||
|
||||
// If we got a retry strategy from the bag, ask it what to do.
|
||||
// If no strategy was set, we won't retry.
|
||||
let should_attempt = halt_on_err!(
|
||||
[ctx] => retry_strategy
|
||||
.map(|rs| rs.should_attempt_retry(ctx, cfg))
|
||||
.unwrap_or(Ok(ShouldAttempt::No)
|
||||
));
|
||||
let should_attempt = match retry_strategy {
|
||||
Some(retry_strategy) => halt_on_err!(
|
||||
[ctx] => retry_strategy.should_attempt_retry(ctx, cfg).map_err(OrchestratorError::other)
|
||||
),
|
||||
None => ShouldAttempt::No,
|
||||
};
|
||||
match should_attempt {
|
||||
// Yes, let's retry the request
|
||||
ShouldAttempt::Yes => continue,
|
||||
|
@ -241,11 +243,11 @@ async fn try_attempt(
|
|||
stop_point: StopPoint,
|
||||
) {
|
||||
halt_on_err!([ctx] => interceptors.read_before_attempt(ctx, cfg));
|
||||
halt_on_err!([ctx] => orchestrate_endpoint(ctx, cfg));
|
||||
halt_on_err!([ctx] => orchestrate_endpoint(ctx, cfg).map_err(OrchestratorError::other));
|
||||
halt_on_err!([ctx] => interceptors.modify_before_signing(ctx, cfg));
|
||||
halt_on_err!([ctx] => interceptors.read_before_signing(ctx, cfg));
|
||||
|
||||
halt_on_err!([ctx] => orchestrate_auth(ctx, cfg).await);
|
||||
halt_on_err!([ctx] => orchestrate_auth(ctx, cfg).await.map_err(OrchestratorError::other));
|
||||
|
||||
halt_on_err!([ctx] => interceptors.read_after_signing(ctx, cfg));
|
||||
halt_on_err!([ctx] => interceptors.modify_before_transmit(ctx, cfg));
|
||||
|
@ -261,7 +263,12 @@ async fn try_attempt(
|
|||
ctx.enter_transmit_phase();
|
||||
let call_result = halt_on_err!([ctx] => {
|
||||
let request = ctx.take_request().expect("set during serialization");
|
||||
cfg.connection().call(request).await
|
||||
cfg.connection().call(request).await.map_err(|err| {
|
||||
match err.downcast() {
|
||||
Ok(connector_error) => OrchestratorError::connector(*connector_error),
|
||||
Err(box_err) => OrchestratorError::other(box_err)
|
||||
}
|
||||
})
|
||||
});
|
||||
ctx.set_response(call_result);
|
||||
ctx.enter_before_deserialization_phase();
|
||||
|
@ -279,7 +286,7 @@ async fn try_attempt(
|
|||
None => read_body(response)
|
||||
.instrument(debug_span!("read_body"))
|
||||
.await
|
||||
.map_err(OrchestratorError::other)
|
||||
.map_err(OrchestratorError::response)
|
||||
.and_then(|_| response_deserializer.deserialize_nonstreaming(response)),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,8 +3,8 @@
|
|||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
use aws_smithy_http::result::SdkError;
|
||||
use aws_smithy_runtime_api::client::interceptors::InterceptorContext;
|
||||
use aws_smithy_runtime_api::client::orchestrator::OrchestratorError;
|
||||
use aws_smithy_runtime_api::client::retries::{ClassifyRetry, RetryReason};
|
||||
use aws_smithy_types::retry::{ErrorKind, ProvideErrorKind};
|
||||
use std::borrow::Cow;
|
||||
|
@ -76,17 +76,17 @@ where
|
|||
Ok(_) => return None,
|
||||
Err(err) => err,
|
||||
};
|
||||
// Check that the error is an operation error
|
||||
let error = error.as_operation_error()?;
|
||||
// Downcast the error
|
||||
let error = error.downcast_ref::<SdkError<E>>()?;
|
||||
|
||||
match error {
|
||||
SdkError::TimeoutError(_) => Some(RetryReason::Error(ErrorKind::TransientError)),
|
||||
SdkError::ResponseError { .. } => Some(RetryReason::Error(ErrorKind::TransientError)),
|
||||
SdkError::DispatchFailure(err) if (err.is_timeout() || err.is_io()) => {
|
||||
OrchestratorError::Response { .. } | OrchestratorError::Timeout { .. } => {
|
||||
Some(RetryReason::Error(ErrorKind::TransientError))
|
||||
}
|
||||
SdkError::DispatchFailure(err) => err.is_other().map(RetryReason::Error),
|
||||
OrchestratorError::Connector { err } if err.is_timeout() || err.is_io() => {
|
||||
Some(RetryReason::Error(ErrorKind::TransientError))
|
||||
}
|
||||
OrchestratorError::Connector { err } if err.is_other().is_some() => {
|
||||
err.is_other().map(RetryReason::Error)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
@ -152,8 +152,6 @@ mod test {
|
|||
HttpStatusCodeClassifier, ModeledAsRetryableClassifier,
|
||||
};
|
||||
use aws_smithy_http::body::SdkBody;
|
||||
use aws_smithy_http::operation;
|
||||
use aws_smithy_http::result::SdkError;
|
||||
use aws_smithy_runtime_api::client::interceptors::InterceptorContext;
|
||||
use aws_smithy_runtime_api::client::orchestrator::OrchestratorError;
|
||||
use aws_smithy_runtime_api::client::retries::{ClassifyRetry, RetryReason};
|
||||
|
@ -242,11 +240,10 @@ mod test {
|
|||
#[test]
|
||||
fn classify_response_error() {
|
||||
let policy = SmithyErrorClassifier::<UnmodeledError>::new();
|
||||
let test_response = http::Response::new("OK").map(SdkBody::from);
|
||||
let err: SdkError<UnmodeledError> =
|
||||
SdkError::response_error(UnmodeledError, operation::Response::new(test_response));
|
||||
let mut ctx = InterceptorContext::new(TypeErasedBox::new("doesntmatter"));
|
||||
ctx.set_output_or_error(Err(OrchestratorError::operation(TypeErasedError::new(err))));
|
||||
ctx.set_output_or_error(Err(OrchestratorError::response(
|
||||
"I am a response error".into(),
|
||||
)));
|
||||
assert_eq!(
|
||||
policy.classify_retry(&ctx),
|
||||
Some(RetryReason::Error(ErrorKind::TransientError)),
|
||||
|
@ -256,9 +253,10 @@ mod test {
|
|||
#[test]
|
||||
fn test_timeout_error() {
|
||||
let policy = SmithyErrorClassifier::<UnmodeledError>::new();
|
||||
let err: SdkError<UnmodeledError> = SdkError::timeout_error("blah");
|
||||
let mut ctx = InterceptorContext::new(TypeErasedBox::new("doesntmatter"));
|
||||
ctx.set_output_or_error(Err(OrchestratorError::operation(TypeErasedError::new(err))));
|
||||
ctx.set_output_or_error(Err(OrchestratorError::timeout(
|
||||
"I am a timeout error".into(),
|
||||
)));
|
||||
assert_eq!(
|
||||
policy.classify_retry(&ctx),
|
||||
Some(RetryReason::Error(ErrorKind::TransientError)),
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
#[cfg(feature = "test-util")]
|
||||
mod fixed_delay;
|
||||
mod never;
|
||||
mod standard;
|
||||
pub(crate) mod standard;
|
||||
|
||||
#[cfg(feature = "test-util")]
|
||||
pub use fixed_delay::FixedDelayRetryStrategy;
|
||||
|
|
|
@ -3,6 +3,10 @@
|
|||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
use crate::client::retries::strategy::standard::ReleaseResult::{
|
||||
APermitWasReleased, NoPermitWasReleased,
|
||||
};
|
||||
use crate::client::runtime_plugin::standard_token_bucket::StandardTokenBucket;
|
||||
use aws_smithy_runtime_api::client::interceptors::InterceptorContext;
|
||||
use aws_smithy_runtime_api::client::orchestrator::{BoxError, ConfigBagAccessors};
|
||||
use aws_smithy_runtime_api::client::request_attempts::RequestAttempts;
|
||||
|
@ -11,16 +15,21 @@ use aws_smithy_runtime_api::client::retries::{
|
|||
};
|
||||
use aws_smithy_types::config_bag::ConfigBag;
|
||||
use aws_smithy_types::retry::RetryConfig;
|
||||
use std::sync::Mutex;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::OwnedSemaphorePermit;
|
||||
|
||||
// The initial attempt, plus three retries.
|
||||
const DEFAULT_MAX_ATTEMPTS: usize = 4;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct StandardRetryStrategy {
|
||||
max_attempts: usize,
|
||||
initial_backoff: Duration,
|
||||
max_backoff: Duration,
|
||||
// Retry settings
|
||||
base: fn() -> f64,
|
||||
initial_backoff: Duration,
|
||||
max_attempts: usize,
|
||||
max_backoff: Duration,
|
||||
retry_permit: Mutex<Option<OwnedSemaphorePermit>>,
|
||||
}
|
||||
|
||||
impl StandardRetryStrategy {
|
||||
|
@ -45,6 +54,36 @@ impl StandardRetryStrategy {
|
|||
self.initial_backoff = initial_backoff;
|
||||
self
|
||||
}
|
||||
|
||||
/// Returns this strategy with its backoff ceiling replaced by `max_backoff`.
///
/// Computed backoff delays are clamped to this value before being handed back to the caller.
pub fn with_max_backoff(self, max_backoff: Duration) -> Self {
    let mut strategy = self;
    strategy.max_backoff = max_backoff;
    strategy
}
|
||||
|
||||
fn release_retry_permit(&self) -> ReleaseResult {
|
||||
let mut retry_permit = self.retry_permit.lock().unwrap();
|
||||
match retry_permit.take() {
|
||||
Some(p) => {
|
||||
drop(p);
|
||||
APermitWasReleased
|
||||
}
|
||||
None => NoPermitWasReleased,
|
||||
}
|
||||
}
|
||||
|
||||
/// Stores `new_retry_permit` as the permit held by this strategy.
///
/// If a permit was already being held, it is forgotten instead of released, which removes it
/// from the bucket permanently.
fn set_retry_permit(&self, new_retry_permit: OwnedSemaphorePermit) {
    let mut permit_slot = self.retry_permit.lock().unwrap();
    let displaced_permit = permit_slot.replace(new_retry_permit);
    if let Some(stale_permit) = displaced_permit {
        // A displaced permit must never flow back into the bucket, so "forget" it.
        stale_permit.forget();
    }
}
|
||||
}
|
||||
|
||||
/// Outcome of asking the retry strategy to give up the permit it is holding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ReleaseResult {
    /// A held permit was found and dropped.
    APermitWasReleased,
    /// No permit was being held, so nothing was released.
    NoPermitWasReleased,
}
|
||||
|
||||
impl Default for StandardRetryStrategy {
|
||||
|
@ -55,13 +94,14 @@ impl Default for StandardRetryStrategy {
|
|||
// by default, use a random base for exponential backoff
|
||||
base: fastrand::f64,
|
||||
initial_backoff: Duration::from_secs(1),
|
||||
retry_permit: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RetryStrategy for StandardRetryStrategy {
|
||||
// TODO(token-bucket) add support for optional cross-request token bucket
|
||||
/// Always allows the initial attempt; the config bag is not consulted.
fn should_attempt_initial_request(&self, _cfg: &ConfigBag) -> Result<ShouldAttempt, BoxError> {
|
||||
// The standard token bucket is only ever considered for retry requests.
|
||||
Ok(ShouldAttempt::Yes)
|
||||
}
|
||||
|
||||
|
@ -74,18 +114,31 @@ impl RetryStrategy for StandardRetryStrategy {
|
|||
let output_or_error = ctx.output_or_error().expect(
|
||||
"This must never be called without reaching the point where the result exists.",
|
||||
);
|
||||
let token_bucket = cfg.get::<StandardTokenBucket>();
|
||||
if output_or_error.is_ok() {
|
||||
tracing::debug!("request succeeded, no retry necessary");
|
||||
if let Some(tb) = token_bucket {
|
||||
// If this retry strategy is holding any permits, release them back to the bucket.
|
||||
if let NoPermitWasReleased = self.release_retry_permit() {
|
||||
// In the event that there was no retry permit to release, we generate new
|
||||
// permits from nothing. We do this to make up for permits we had to "forget".
|
||||
// Otherwise, repeated retries would empty the bucket and nothing could fill it
|
||||
// back up again.
|
||||
tb.regenerate_a_token();
|
||||
}
|
||||
}
|
||||
|
||||
return Ok(ShouldAttempt::No);
|
||||
}
|
||||
|
||||
// Check if we're out of attempts
|
||||
let request_attempts: &RequestAttempts = cfg
|
||||
.get()
|
||||
.expect("at least one request attempt is made before any retry is attempted");
|
||||
if request_attempts.attempts() >= self.max_attempts {
|
||||
let request_attempts = cfg
|
||||
.get::<RequestAttempts>()
|
||||
.expect("at least one request attempt is made before any retry is attempted")
|
||||
.attempts();
|
||||
if request_attempts >= self.max_attempts {
|
||||
tracing::trace!(
|
||||
attempts = request_attempts.attempts(),
|
||||
attempts = request_attempts,
|
||||
max_attempts = self.max_attempts,
|
||||
"not retrying because we are out of attempts"
|
||||
);
|
||||
|
@ -95,9 +148,24 @@ impl RetryStrategy for StandardRetryStrategy {
|
|||
// Run the classifiers against the context to determine if we should retry
|
||||
let retry_classifiers = cfg.retry_classifiers();
|
||||
let retry_reason = retry_classifiers.classify_retry(ctx);
|
||||
|
||||
// Calculate the appropriate backoff time.
|
||||
let backoff = match retry_reason {
|
||||
Some(RetryReason::Explicit(dur)) => dur,
|
||||
Some(RetryReason::Error(_)) => {
|
||||
Some(RetryReason::Error(kind)) => {
|
||||
// If a token bucket was set, and the RetryReason IS NOT explicit, attempt to acquire a retry permit.
|
||||
if let Some(tb) = token_bucket {
|
||||
match tb.acquire(&kind) {
|
||||
Some(permit) => self.set_retry_permit(permit),
|
||||
None => {
|
||||
tracing::debug!(
|
||||
"attempt #{request_attempts} failed with {kind:?}; However, no retry permits are available, so no retry will be attempted.",
|
||||
);
|
||||
return Ok(ShouldAttempt::No);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let backoff = calculate_exponential_backoff(
|
||||
// Generate a random base multiplier to create jitter
|
||||
(self.base)(),
|
||||
|
@ -105,16 +173,14 @@ impl RetryStrategy for StandardRetryStrategy {
|
|||
self.initial_backoff.as_secs_f64(),
|
||||
// `self.local.attempts` tracks number of requests made including the initial request
|
||||
// The initial attempt shouldn't count towards backoff calculations so we subtract it
|
||||
(request_attempts.attempts() - 1) as u32,
|
||||
(request_attempts - 1) as u32,
|
||||
);
|
||||
Duration::from_secs_f64(backoff).min(self.max_backoff)
|
||||
}
|
||||
Some(_) => {
|
||||
unreachable!("RetryReason is non-exhaustive. Therefore, we need to cover this unreachable case.")
|
||||
}
|
||||
Some(_) => unreachable!("RetryReason is non-exhaustive"),
|
||||
None => {
|
||||
tracing::trace!(
|
||||
attempts = request_attempts.attempts(),
|
||||
attempts = request_attempts,
|
||||
max_attempts = self.max_attempts,
|
||||
"encountered unretryable error"
|
||||
);
|
||||
|
@ -123,8 +189,7 @@ impl RetryStrategy for StandardRetryStrategy {
|
|||
};
|
||||
|
||||
tracing::debug!(
|
||||
"attempt {} failed with {:?}; retrying after {:?}",
|
||||
request_attempts.attempts(),
|
||||
"attempt #{request_attempts} failed with {:?}; retrying after {:?}",
|
||||
retry_reason.expect("the match statement above ensures this is not None"),
|
||||
backoff
|
||||
);
|
||||
|
@ -139,16 +204,23 @@ fn calculate_exponential_backoff(base: f64, initial_backoff: f64, retry_attempts
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{ShouldAttempt, StandardRetryStrategy};
|
||||
use super::{calculate_exponential_backoff, ShouldAttempt, StandardRetryStrategy};
|
||||
use aws_smithy_runtime_api::client::interceptors::InterceptorContext;
|
||||
use aws_smithy_runtime_api::client::orchestrator::{ConfigBagAccessors, OrchestratorError};
|
||||
use aws_smithy_runtime_api::client::request_attempts::RequestAttempts;
|
||||
use aws_smithy_runtime_api::client::retries::{AlwaysRetry, RetryClassifiers, RetryStrategy};
|
||||
use aws_smithy_runtime_api::client::retries::{
|
||||
AlwaysRetry, ClassifyRetry, RetryClassifiers, RetryReason, RetryStrategy,
|
||||
};
|
||||
use aws_smithy_types::config_bag::{ConfigBag, Layer};
|
||||
use aws_smithy_types::retry::ErrorKind;
|
||||
use aws_smithy_types::retry::{ErrorKind, ProvideErrorKind};
|
||||
use aws_smithy_types::type_erasure::TypeErasedBox;
|
||||
use std::fmt;
|
||||
use std::sync::Mutex;
|
||||
use std::time::Duration;
|
||||
|
||||
#[cfg(feature = "test-util")]
|
||||
use crate::client::runtime_plugin::standard_token_bucket::StandardTokenBucket;
|
||||
|
||||
#[test]
|
||||
fn no_retry_necessary_for_ok_result() {
|
||||
let cfg = ConfigBag::base();
|
||||
|
@ -221,4 +293,351 @@ mod tests {
|
|||
.expect("method is infallible for this use");
|
||||
assert_eq!(ShouldAttempt::No, actual);
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ServerError;
|
||||
impl fmt::Display for ServerError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "OperationError")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for ServerError {}
|
||||
|
||||
impl ProvideErrorKind for ServerError {
|
||||
fn retryable_error_kind(&self) -> Option<ErrorKind> {
|
||||
Some(ErrorKind::ServerError)
|
||||
}
|
||||
|
||||
fn code(&self) -> Option<&str> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PresetReasonRetryClassifier {
|
||||
retry_reasons: Mutex<Vec<RetryReason>>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "test-util")]
|
||||
impl PresetReasonRetryClassifier {
|
||||
fn new(mut retry_reasons: Vec<RetryReason>) -> Self {
|
||||
// We'll pop the retry_reasons in reverse order so we reverse the list to fix that.
|
||||
retry_reasons.reverse();
|
||||
Self {
|
||||
retry_reasons: Mutex::new(retry_reasons),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ClassifyRetry for PresetReasonRetryClassifier {
|
||||
fn classify_retry(&self, ctx: &InterceptorContext) -> Option<RetryReason> {
|
||||
if ctx.output_or_error().map(|it| it.is_ok()).unwrap_or(false) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut retry_reasons = self.retry_reasons.lock().unwrap();
|
||||
if retry_reasons.len() == 1 {
|
||||
Some(retry_reasons.first().unwrap().clone())
|
||||
} else {
|
||||
retry_reasons.pop()
|
||||
}
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"Always returns a preset retry reason"
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds a config bag whose only retry classifier replays `retry_reasons`, together with an
/// interceptor context that already holds a failed result.
#[cfg(feature = "test-util")]
fn setup_test(retry_reasons: Vec<RetryReason>) -> (ConfigBag, InterceptorContext) {
    let classifiers =
        RetryClassifiers::new().with_classifier(PresetReasonRetryClassifier::new(retry_reasons));
    let mut cfg = ConfigBag::base();
    cfg.interceptor_state().set_retry_classifiers(classifiers);

    // The concrete error is irrelevant: the classifier returns whatever it was scripted with.
    let mut ctx = InterceptorContext::new(TypeErasedBox::doesnt_matter());
    ctx.set_output_or_error(Err(OrchestratorError::other("doesn't matter")));

    (cfg, ctx)
}
|
||||
|
||||
/// Two failed attempts each take retry permits from the bucket; the success that follows gives
/// the currently held permits back.
#[cfg(feature = "test-util")]
#[test]
fn eventual_success() {
    let (mut cfg, mut ctx) = setup_test(vec![RetryReason::Error(ErrorKind::ServerError)]);
    let strategy = StandardRetryStrategy::default()
        .with_base(|| 1.0)
        .with_max_attempts(5);
    cfg.interceptor_state().put(StandardTokenBucket::default());
    let token_bucket = cfg.get::<StandardTokenBucket>().unwrap().clone();

    // (attempt number, expected delay in seconds, permits left in the bucket afterwards)
    for (attempt, delay_secs, permits_left) in [(1, 1, 495), (2, 2, 490)] {
        cfg.interceptor_state().put(RequestAttempts::new(attempt));
        let delay = strategy
            .should_attempt_retry(&ctx, &cfg)
            .unwrap()
            .expect_delay();
        assert_eq!(delay, Duration::from_secs(delay_secs));
        assert_eq!(token_bucket.available_permits(), permits_left);
    }

    // The third attempt succeeds, so no retry is needed and the held permit is released.
    ctx.set_output_or_error(Ok(TypeErasedBox::doesnt_matter()));
    cfg.interceptor_state().put(RequestAttempts::new(3));
    let verdict = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
    assert_eq!(verdict, ShouldAttempt::No);
    assert_eq!(token_bucket.available_permits(), 495);
}
|
||||
|
||||
/// Once the attempt count reaches `max_attempts`, no retry is granted and no further permits
/// are taken from the bucket.
#[cfg(feature = "test-util")]
#[test]
fn no_more_attempts() {
    let (mut cfg, ctx) = setup_test(vec![RetryReason::Error(ErrorKind::ServerError)]);
    let strategy = StandardRetryStrategy::default()
        .with_base(|| 1.0)
        .with_max_attempts(3);
    cfg.interceptor_state().put(StandardTokenBucket::default());
    let token_bucket = cfg.get::<StandardTokenBucket>().unwrap().clone();

    // (attempt number, expected delay in seconds, permits left in the bucket afterwards)
    for (attempt, delay_secs, permits_left) in [(1, 1, 495), (2, 2, 490)] {
        cfg.interceptor_state().put(RequestAttempts::new(attempt));
        let delay = strategy
            .should_attempt_retry(&ctx, &cfg)
            .unwrap()
            .expect_delay();
        assert_eq!(delay, Duration::from_secs(delay_secs));
        assert_eq!(token_bucket.available_permits(), permits_left);
    }

    // Attempt three is the last one allowed.
    cfg.interceptor_state().put(RequestAttempts::new(3));
    let verdict = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
    assert_eq!(verdict, ShouldAttempt::No);
    assert_eq!(token_bucket.available_permits(), 490);
}
|
||||
|
||||
/// With only five permits in the bucket, the first retry drains it and the second is refused.
#[cfg(feature = "test-util")]
#[test]
fn no_quota() {
    let (mut cfg, ctx) = setup_test(vec![RetryReason::Error(ErrorKind::ServerError)]);
    let strategy = StandardRetryStrategy::default()
        .with_base(|| 1.0)
        .with_max_attempts(5);
    cfg.interceptor_state().put(StandardTokenBucket::new(5));
    let token_bucket = cfg.get::<StandardTokenBucket>().unwrap().clone();

    // First failure: the retry is granted and uses up the whole bucket.
    cfg.interceptor_state().put(RequestAttempts::new(1));
    let delay = strategy
        .should_attempt_retry(&ctx, &cfg)
        .unwrap()
        .expect_delay();
    assert_eq!(delay, Duration::from_secs(1));
    assert_eq!(token_bucket.available_permits(), 0);

    // Second failure: nothing is left to acquire, so no retry happens.
    cfg.interceptor_state().put(RequestAttempts::new(2));
    let verdict = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
    assert_eq!(verdict, ShouldAttempt::No);
    assert_eq!(token_bucket.available_permits(), 0);
}
|
||||
|
||||
/// A transient error takes permits, an explicit retry reason takes none, and the eventual
/// success brings the bucket back to its starting level.
#[cfg(feature = "test-util")]
#[test]
fn quota_replenishes_on_success() {
    let scripted_reasons = vec![
        RetryReason::Error(ErrorKind::TransientError),
        RetryReason::Explicit(Duration::from_secs(1)),
    ];
    let (mut cfg, mut ctx) = setup_test(scripted_reasons);
    let strategy = StandardRetryStrategy::default()
        .with_base(|| 1.0)
        .with_max_attempts(5);
    cfg.interceptor_state().put(StandardTokenBucket::new(100));
    let token_bucket = cfg.get::<StandardTokenBucket>().unwrap().clone();

    // Both failed attempts wait one second and leave 90 permits available.
    for attempt in [1, 2] {
        cfg.interceptor_state().put(RequestAttempts::new(attempt));
        let delay = strategy
            .should_attempt_retry(&ctx, &cfg)
            .unwrap()
            .expect_delay();
        assert_eq!(delay, Duration::from_secs(1));
        assert_eq!(token_bucket.available_permits(), 90);
    }

    ctx.set_output_or_error(Ok(TypeErasedBox::doesnt_matter()));

    cfg.interceptor_state().put(RequestAttempts::new(3));
    let verdict = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
    assert_eq!(verdict, ShouldAttempt::No);

    assert_eq!(token_bucket.available_permits(), 100);
}
|
||||
|
||||
/// Drains the bucket with failures, forgets the held permit, and then checks that successful
/// calls refill the bucket one permit at a time until it is full again.
#[cfg(feature = "test-util")]
#[test]
fn quota_replenishes_on_first_try_success() {
    const PERMIT_COUNT: usize = 20;
    let (mut cfg, mut ctx) = setup_test(vec![RetryReason::Error(ErrorKind::TransientError)]);
    let strategy = StandardRetryStrategy::default()
        .with_base(|| 1.0)
        .with_max_attempts(usize::MAX);
    cfg.interceptor_state()
        .put(StandardTokenBucket::new(PERMIT_COUNT));
    let token_bucket = cfg.get::<StandardTokenBucket>().unwrap().clone();

    let mut attempt = 1;

    // Phase 1: fail until the bucket is empty. Two attempts are expected to be enough.
    while token_bucket.available_permits() > 0 {
        assert!(
            attempt <= 2,
            "This test should have completed by now (drain)"
        );

        cfg.interceptor_state().put(RequestAttempts::new(attempt));
        let verdict = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
        assert!(matches!(verdict, ShouldAttempt::YesAfterDelay(_)));
        attempt += 1;
    }

    // Throw away the held permit so the only way to refill is "success on first try".
    strategy
        .retry_permit
        .lock()
        .unwrap()
        .take()
        .unwrap()
        .forget();

    ctx.set_output_or_error(Ok(TypeErasedBox::doesnt_matter()));

    // Phase 2: succeed repeatedly until the bucket is back at `PERMIT_COUNT`.
    while token_bucket.available_permits() < PERMIT_COUNT {
        assert!(
            attempt <= 23,
            "This test should have completed by now (fillup)"
        );

        cfg.interceptor_state().put(RequestAttempts::new(attempt));
        let verdict = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
        assert_eq!(verdict, ShouldAttempt::No);
        attempt += 1;
    }

    assert_eq!(attempt, 23);
    assert_eq!(token_bucket.available_permits(), PERMIT_COUNT);
}
|
||||
|
||||
/// With a fixed base of 1.0 the delay doubles on every retry (1s, 2s, 4s, 8s), and every retry
/// leaves five fewer permits in the bucket.
#[cfg(feature = "test-util")]
#[test]
fn backoff_timing() {
    let (mut cfg, ctx) = setup_test(vec![RetryReason::Error(ErrorKind::ServerError)]);
    let strategy = StandardRetryStrategy::default()
        .with_base(|| 1.0)
        .with_max_attempts(5);
    cfg.interceptor_state().put(StandardTokenBucket::default());
    let token_bucket = cfg.get::<StandardTokenBucket>().unwrap().clone();

    // (attempt number, expected delay in seconds, permits left in the bucket afterwards)
    let expectations = [(1, 1, 495), (2, 2, 490), (3, 4, 485), (4, 8, 480)];
    for (attempt, delay_secs, permits_left) in expectations {
        cfg.interceptor_state().put(RequestAttempts::new(attempt));
        let delay = strategy
            .should_attempt_retry(&ctx, &cfg)
            .unwrap()
            .expect_delay();
        assert_eq!(delay, Duration::from_secs(delay_secs));
        assert_eq!(token_bucket.available_permits(), permits_left);
    }

    // The fifth attempt hits `max_attempts`, so no retry is granted and no permits are taken.
    cfg.interceptor_state().put(RequestAttempts::new(5));
    let verdict = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
    assert_eq!(verdict, ShouldAttempt::No);
    assert_eq!(token_bucket.available_permits(), 480);
}
|
||||
|
||||
/// Same schedule as `backoff_timing`, but with a three second ceiling: delays go 1s, 2s, 3s, 3s.
#[cfg(feature = "test-util")]
#[test]
fn max_backoff_time() {
    let (mut cfg, ctx) = setup_test(vec![RetryReason::Error(ErrorKind::ServerError)]);
    let strategy = StandardRetryStrategy::default()
        .with_base(|| 1.0)
        .with_max_attempts(5)
        .with_initial_backoff(Duration::from_secs(1))
        .with_max_backoff(Duration::from_secs(3));
    cfg.interceptor_state().put(StandardTokenBucket::default());
    let token_bucket = cfg.get::<StandardTokenBucket>().unwrap().clone();

    // (attempt number, expected delay in seconds, permits left in the bucket afterwards)
    let expectations = [(1, 1, 495), (2, 2, 490), (3, 3, 485), (4, 3, 480)];
    for (attempt, delay_secs, permits_left) in expectations {
        cfg.interceptor_state().put(RequestAttempts::new(attempt));
        let delay = strategy
            .should_attempt_retry(&ctx, &cfg)
            .unwrap()
            .expect_delay();
        assert_eq!(delay, Duration::from_secs(delay_secs));
        assert_eq!(token_bucket.available_permits(), permits_left);
    }

    // The fifth attempt hits `max_attempts`, so no retry is granted and no permits are taken.
    cfg.interceptor_state().put(RequestAttempts::new(5));
    let verdict = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
    assert_eq!(verdict, ShouldAttempt::No);
    assert_eq!(token_bucket.available_permits(), 480);
}
|
||||
|
||||
/// With a base of 1.0 and an initial backoff of 1.0, retries 0, 1 and 2 yield 1, 2 and 4.
#[test]
fn calculate_exponential_backoff_where_initial_backoff_is_one() {
    let initial_backoff = 1.0;
    let expected_backoffs = [initial_backoff, 2.0, 4.0];

    for (retry_attempt, expected) in (0u32..).zip(expected_backoffs) {
        let actual = calculate_exponential_backoff(1.0, initial_backoff, retry_attempt);
        assert_eq!(expected, actual);
    }
}
|
||||
|
||||
/// With a base of 1.0 and an initial backoff of 3.0, retries 0, 1 and 2 yield 3, 6 and 12.
#[test]
fn calculate_exponential_backoff_where_initial_backoff_is_greater_than_one() {
    let initial_backoff = 3.0;
    let expected_backoffs = [initial_backoff, 6.0, 12.0];

    for (retry_attempt, expected) in (0u32..).zip(expected_backoffs) {
        let actual = calculate_exponential_backoff(1.0, initial_backoff, retry_attempt);
        assert_eq!(expected, actual);
    }
}
|
||||
|
||||
/// With a base of 1.0 and an initial backoff of 0.03, retries 0, 1 and 2 yield 0.03, 0.06 and
/// 0.12.
#[test]
fn calculate_exponential_backoff_where_initial_backoff_is_less_than_one() {
    let initial_backoff = 0.03;
    let expected_backoffs = [initial_backoff, 0.06, 0.12];

    for (retry_attempt, expected) in (0u32..).zip(expected_backoffs) {
        let actual = calculate_exponential_backoff(1.0, initial_backoff, retry_attempt);
        assert_eq!(expected, actual);
    }
}
|
||||
}
|
||||
|
|
|
@ -5,3 +5,5 @@
|
|||
|
||||
#[cfg(feature = "anonymous-auth")]
|
||||
pub mod anonymous_auth;
|
||||
|
||||
pub mod standard_token_bucket;
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin;
|
||||
use aws_smithy_types::config_bag::{FrozenLayer, Layer, Storable, StoreReplace};
|
||||
use aws_smithy_types::retry::ErrorKind;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
|
||||
use tracing::trace;
|
||||
|
||||
/// A [RuntimePlugin] to provide a standard token bucket, usable by the
|
||||
/// [`StandardRetryStrategy`](crate::client::retries::strategy::standard::StandardRetryStrategy).
|
||||
#[non_exhaustive]
|
||||
#[derive(Debug, Default)]
|
||||
pub struct StandardTokenBucketRuntimePlugin {
|
||||
// The bucket that `config()` clones into the layer it returns.
token_bucket: StandardTokenBucket,
|
||||
}
|
||||
|
||||
impl StandardTokenBucketRuntimePlugin {
|
||||
pub fn new(initial_tokens: usize) -> Self {
|
||||
Self {
|
||||
token_bucket: StandardTokenBucket::new(initial_tokens),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RuntimePlugin for StandardTokenBucketRuntimePlugin {
|
||||
fn config(&self) -> Option<FrozenLayer> {
|
||||
let mut cfg = Layer::new("standard token bucket");
|
||||
cfg.store_put(self.token_bucket.clone());
|
||||
|
||||
Some(cfg.freeze())
|
||||
}
|
||||
}
|
||||
|
||||
// Number of permits in a default-constructed bucket; also used as that bucket's maximum.
const DEFAULT_CAPACITY: usize = 500;
|
||||
// Permits acquired for a retry of any error kind other than `ErrorKind::TransientError`.
const RETRY_COST: u32 = 5;
|
||||
// Permits acquired for a retry of an `ErrorKind::TransientError`: twice the regular cost.
const RETRY_TIMEOUT_COST: u32 = RETRY_COST * 2;
|
||||
// Permits added by a single call to `regenerate_a_token`.
const PERMIT_REGENERATION_AMOUNT: usize = 1;
|
||||
|
||||
/// Token bucket backed by a semaphore, consulted by the standard retry strategy before retries.
///
/// The semaphore is held in an `Arc`, so a cloned bucket uses the same semaphore as the original.
#[derive(Clone, Debug)]
|
||||
pub(crate) struct StandardTokenBucket {
|
||||
// Semaphore whose available permits are the tokens currently in the bucket.
semaphore: Arc<Semaphore>,
|
||||
// `regenerate_a_token` only adds a permit while fewer than this many are available.
max_permits: usize,
|
||||
// Permits taken when retrying an `ErrorKind::TransientError`.
timeout_retry_cost: u32,
|
||||
// Permits taken when retrying any other error kind.
retry_cost: u32,
|
||||
}
|
||||
|
||||
// Allows the bucket to be put into a config layer (the runtime plugin does this via `store_put`).
// NOTE(review): `StoreReplace` presumably means a newer value replaces an older one — confirm
// against the config bag documentation.
impl Storable for StandardTokenBucket {
|
||||
type Storer = StoreReplace<Self>;
|
||||
}
|
||||
|
||||
impl Default for StandardTokenBucket {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
semaphore: Arc::new(Semaphore::new(DEFAULT_CAPACITY)),
|
||||
max_permits: DEFAULT_CAPACITY,
|
||||
timeout_retry_cost: RETRY_TIMEOUT_COST,
|
||||
retry_cost: RETRY_COST,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StandardTokenBucket {
|
||||
pub(crate) fn new(initial_quota: usize) -> Self {
|
||||
Self {
|
||||
semaphore: Arc::new(Semaphore::new(initial_quota)),
|
||||
max_permits: initial_quota,
|
||||
retry_cost: RETRY_COST,
|
||||
timeout_retry_cost: RETRY_TIMEOUT_COST,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn acquire(&self, err: &ErrorKind) -> Option<OwnedSemaphorePermit> {
|
||||
let retry_cost = if err == &ErrorKind::TransientError {
|
||||
self.timeout_retry_cost
|
||||
} else {
|
||||
self.retry_cost
|
||||
};
|
||||
|
||||
self.semaphore
|
||||
.clone()
|
||||
.try_acquire_many_owned(retry_cost)
|
||||
.ok()
|
||||
}
|
||||
|
||||
pub(crate) fn regenerate_a_token(&self) {
|
||||
if self.semaphore.available_permits() < (self.max_permits) {
|
||||
trace!("adding {PERMIT_REGENERATION_AMOUNT} back into the bucket");
|
||||
self.semaphore.add_permits(PERMIT_REGENERATION_AMOUNT)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "test-util"))]
|
||||
pub(crate) fn available_permits(&self) -> usize {
|
||||
self.semaphore.available_permits()
|
||||
}
|
||||
}
|
|
@ -143,6 +143,7 @@ pub struct RetryConfigBuilder {
|
|||
mode: Option<RetryMode>,
|
||||
max_attempts: Option<u32>,
|
||||
initial_backoff: Option<Duration>,
|
||||
max_backoff: Option<Duration>,
|
||||
reconnect_mode: Option<ReconnectMode>,
|
||||
}
|
||||
|
||||
|
@ -212,6 +213,18 @@ impl RetryConfigBuilder {
|
|||
self
|
||||
}
|
||||
|
||||
/// Sets the maximum backoff duration, or clears it when `None` is passed. A provided duration
/// should be non-zero. If it is still unset when the config is built, 20 seconds is used.
|
||||
pub fn set_max_backoff(&mut self, max_backoff: Option<Duration>) -> &mut Self {
|
||||
self.max_backoff = max_backoff;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the max_backoff duration. This duration should be non-zero.
|
||||
pub fn max_backoff(mut self, max_backoff: Duration) -> Self {
|
||||
self.set_max_backoff(Some(max_backoff));
|
||||
self
|
||||
}
|
||||
|
||||
/// Merge two builders together. Values from `other` will only be used as a fallback for values
|
||||
/// from `self` Useful for merging configs from different sources together when you want to
|
||||
/// handle "precedence" per value instead of at the config level
|
||||
|
@ -233,6 +246,7 @@ impl RetryConfigBuilder {
|
|||
mode: self.mode.or(other.mode),
|
||||
max_attempts: self.max_attempts.or(other.max_attempts),
|
||||
initial_backoff: self.initial_backoff.or(other.initial_backoff),
|
||||
max_backoff: self.max_backoff.or(other.max_backoff),
|
||||
reconnect_mode: self.reconnect_mode.or(other.reconnect_mode),
|
||||
}
|
||||
}
|
||||
|
@ -248,6 +262,7 @@ impl RetryConfigBuilder {
|
|||
reconnect_mode: self
|
||||
.reconnect_mode
|
||||
.unwrap_or(ReconnectMode::ReconnectOnTransientError),
|
||||
max_backoff: self.max_backoff.unwrap_or_else(|| Duration::from_secs(20)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -259,6 +274,7 @@ pub struct RetryConfig {
|
|||
mode: RetryMode,
|
||||
max_attempts: u32,
|
||||
initial_backoff: Duration,
|
||||
max_backoff: Duration,
|
||||
reconnect_mode: ReconnectMode,
|
||||
}
|
||||
|
||||
|
@ -286,6 +302,7 @@ impl RetryConfig {
|
|||
max_attempts: 3,
|
||||
initial_backoff: Duration::from_secs(1),
|
||||
reconnect_mode: ReconnectMode::ReconnectOnTransientError,
|
||||
max_backoff: Duration::from_secs(20),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue