Add Retry support (#232)

* Add Retry support

* Fixup some broken tests

* Refactor, add docs, rename retry strategy

* Some more renames, some more docs
This commit is contained in:
Russell Cohen 2021-03-02 17:32:55 -05:00 committed by GitHub
parent 6eaae060dc
commit 8f844c579c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 735 additions and 62 deletions

View File

@ -55,6 +55,8 @@ jobs:
java-version: ${{ env.java_version }}
- name: test
run: ./gradlew :codegen:test
- name: aws tests
run: ./gradlew :aws:sdk-codegen:test
integration-tests:
name: Codegen integration tests
runs-on: ubuntu-latest

View File

@ -1,4 +1,6 @@
pub mod user_agent;
use smithy_http::result::SdkError;
use smithy_http::retry::ClassifyResponse;
use smithy_types::retry::{ErrorKind, ProvideErrorKind, RetryKind};
use std::time::Duration;
@ -11,6 +13,7 @@ use std::time::Duration;
/// 3. The code is checked against a predetermined list of throttling errors & transient error codes
/// 4. The status code is checked against a predetermined list of status codes
#[non_exhaustive]
#[derive(Clone)]
pub struct AwsErrorRetryPolicy;
const TRANSIENT_ERROR_STATUS_CODES: [u16; 2] = [400, 408];
@ -45,11 +48,16 @@ impl Default for AwsErrorRetryPolicy {
}
}
impl ClassifyResponse for AwsErrorRetryPolicy {
fn classify<E, B>(&self, err: E, response: &http::Response<B>) -> RetryKind
where
E: ProvideErrorKind,
{
impl<T, E, B> ClassifyResponse<T, SdkError<E, B>> for AwsErrorRetryPolicy
where
E: ProvideErrorKind,
{
fn classify(&self, err: Result<&T, &SdkError<E, B>>) -> RetryKind {
let (err, response) = match err {
Ok(_) => return RetryKind::NotRetryable,
Err(SdkError::ServiceError { err, raw }) => (err, raw),
Err(_) => return RetryKind::NotRetryable,
};
if let Some(retry_after_delay) = response
.headers()
.get("x-amz-retry-after")
@ -80,6 +88,7 @@ impl ClassifyResponse for AwsErrorRetryPolicy {
#[cfg(test)]
mod test {
use crate::AwsErrorRetryPolicy;
use smithy_http::result::{SdkError, SdkSuccess};
use smithy_http::retry::ClassifyResponse;
use smithy_types::retry::{ErrorKind, ProvideErrorKind, RetryKind};
use std::time::Duration;
@ -110,12 +119,16 @@ mod test {
}
}
fn make_err<E, B>(err: E, raw: http::Response<B>) -> Result<SdkSuccess<(), B>, SdkError<E, B>> {
Err(SdkError::ServiceError { err, raw })
}
#[test]
fn not_an_error() {
let policy = AwsErrorRetryPolicy::new();
let test_response = http::Response::new("OK");
assert_eq!(
policy.classify(UnmodeledError, &test_response),
policy.classify(make_err(UnmodeledError, test_response).as_ref()),
RetryKind::NotRetryable
);
}
@ -128,7 +141,7 @@ mod test {
.body("error!")
.unwrap();
assert_eq!(
policy.classify(UnmodeledError, &test_resp),
policy.classify(make_err(UnmodeledError, test_resp).as_ref()),
RetryKind::Error(ErrorKind::TransientError)
);
}
@ -139,16 +152,20 @@ mod test {
let policy = AwsErrorRetryPolicy::new();
assert_eq!(
policy.classify(CodedError { code: "Throttling" }, &test_response),
policy.classify(make_err(CodedError { code: "Throttling" }, test_response).as_ref()),
RetryKind::Error(ErrorKind::ThrottlingError)
);
let test_response = http::Response::new("OK");
assert_eq!(
policy.classify(
CodedError {
code: "RequestTimeout"
},
&test_response,
make_err(
CodedError {
code: "RequestTimeout"
},
test_response
)
.as_ref()
),
RetryKind::Error(ErrorKind::TransientError)
)
@ -164,7 +181,7 @@ mod test {
let test_response = http::Response::new("OK");
let policy = AwsErrorRetryPolicy::new();
assert_eq!(
policy.classify(err, &test_response),
policy.classify(make_err(err, test_response).as_ref()),
RetryKind::Error(ErrorKind::ThrottlingError)
);
}
@ -187,7 +204,7 @@ mod test {
let policy = AwsErrorRetryPolicy::new();
assert_eq!(
policy.classify(ModeledRetries, &test_response),
policy.classify(make_err(ModeledRetries, test_response).as_ref()),
RetryKind::Error(ErrorKind::ClientError)
);
}
@ -201,7 +218,7 @@ mod test {
.unwrap();
assert_eq!(
policy.classify(UnmodeledError, &test_response),
policy.classify(make_err(UnmodeledError, test_response).as_ref()),
RetryKind::Explicit(Duration::from_millis(5000))
);
}

View File

@ -8,7 +8,7 @@ edition = "2018"
[dependencies]
hyper = { version = "0.14.2", features = ["client", "http1", "http2", "tcp", "runtime"] }
tower = { version = "0.4.6", features = ["util"] }
tower = { version = "0.4.6", features = ["util", "retry"] }
hyper-tls = "0.5.0"
aws-auth = { path = "../aws-auth" }
aws-sig-auth = { path = "../aws-sig-auth" }
@ -18,9 +18,12 @@ http = "0.2.3"
bytes = "1"
http-body = "0.4.0"
smithy-http = { path = "../../../rust-runtime/smithy-http" }
smithy-types = { path = "../../../rust-runtime/smithy-types" }
smithy-http-tower = { path = "../../../rust-runtime/smithy-http-tower" }
fastrand = "1.4.0"
tokio = { version = "1", features = ["time"]}
[dev-dependencies]
tokio = { version = "1", features = ["full"] }
tokio = { version = "1", features = ["full", "test-util"] }
tower-test = "0.4.0"
aws-types = { path = "../aws-types" }

View File

@ -1,5 +1,8 @@
mod retry;
pub mod test_connection;
pub use retry::RetryConfig;
use crate::retry::RetryHandlerFactory;
use aws_endpoint::AwsEndpointStage;
use aws_http::user_agent::UserAgentStage;
use aws_sig_auth::middleware::SigV4SigningStage;
@ -10,9 +13,11 @@ use hyper_tls::HttpsConnector;
use smithy_http::body::SdkBody;
use smithy_http::operation::Operation;
use smithy_http::response::ParseHttpResponse;
use smithy_http::retry::ClassifyResponse;
use smithy_http_tower::dispatch::DispatchLayer;
use smithy_http_tower::map_request::MapRequestLayer;
use smithy_http_tower::parse_response::ParseResponseLayer;
use smithy_types::retry::ProvideErrorKind;
use std::error::Error;
use tower::{Service, ServiceBuilder, ServiceExt};
@ -39,12 +44,21 @@ pub type SdkSuccess<T> = smithy_http::result::SdkSuccess<T, hyper::Body>;
pub struct Client<S> {
inner: S,
retry_strategy: RetryHandlerFactory,
}
impl<S> Client<S> {
/// Construct a new `Client` with a custom connector
pub fn new(connector: S) -> Self {
Client { inner: connector }
Client {
inner: connector,
retry_strategy: RetryHandlerFactory::new(RetryConfig::default()),
}
}
pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
self.retry_strategy.with_config(retry_config);
self
}
}
@ -53,7 +67,10 @@ impl Client<hyper::Client<HttpsConnector<HttpConnector>, SdkBody>> {
pub fn https() -> Self {
let https = HttpsConnector::new();
let client = HyperClient::builder().build::<_, SdkBody>(https);
Client { inner: client }
Client {
inner: client,
retry_strategy: RetryHandlerFactory::new(RetryConfig::default()),
}
}
}
@ -72,8 +89,9 @@ where
/// access the raw response use `call_raw`.
pub async fn call<O, T, E, Retry>(&self, input: Operation<O, Retry>) -> Result<T, SdkError<E>>
where
O: ParseHttpResponse<hyper::Body, Output = Result<T, E>> + Send + 'static,
E: Error,
O: ParseHttpResponse<hyper::Body, Output = Result<T, E>> + Send + Clone + 'static,
E: Error + ProvideErrorKind,
Retry: ClassifyResponse<SdkSuccess<T>, SdkError<E>>,
{
self.call_raw(input).await.map(|res| res.parsed)
}
@ -87,14 +105,17 @@ where
input: Operation<O, Retry>,
) -> Result<SdkSuccess<R>, SdkError<E>>
where
O: ParseHttpResponse<hyper::Body, Output = Result<R, E>> + Send + 'static,
E: Error,
O: ParseHttpResponse<hyper::Body, Output = Result<R, E>> + Send + Clone + 'static,
E: Error + ProvideErrorKind,
Retry: ClassifyResponse<SdkSuccess<R>, SdkError<E>>,
{
let signer = MapRequestLayer::for_mapper(SigV4SigningStage::new(SigV4Signer::new()));
let endpoint_resolver = MapRequestLayer::for_mapper(AwsEndpointStage);
let user_agent = MapRequestLayer::for_mapper(UserAgentStage::new());
let inner = self.inner.clone();
let mut svc = ServiceBuilder::new()
// Create a new request-scoped policy
.retry(self.retry_strategy.new_handler())
.layer(ParseResponseLayer::<O, Retry>::new())
.layer(endpoint_resolver)
.layer(signer)

View File

@ -0,0 +1,407 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
//! Retry support for aws-hyper
//!
//! The actual retry policy implementation will likely be replaced
//! with the CRT implementation once the bindings exist. This
//! implementation is intended to be _correct_ but not especially long lasting.
//!
//! Components:
//! - [`RetryHandlerFactory`](crate::retry::RetryHandlerFactory): Top level manager, intended
//! to be associated with a [`Client`](crate::Client). Its sole purpose in life is to create a RetryHandler
//! for individual requests.
//! - [`RetryHandler`](crate::retry::RetryHandler): A request-scoped retry policy,
//! backed by request-local state and shared state contained within [`RetryHandlerFactory`](crate::retry::RetryHandlerFactory)
//! - [`RetryConfig`](crate::retry::RetryConfig): Static configuration (max retries, max backoff etc.)
use crate::{SdkError, SdkSuccess};
use smithy_http::operation;
use smithy_http::operation::Operation;
use smithy_http::retry::ClassifyResponse;
use smithy_types::retry::{ErrorKind, ProvideErrorKind, RetryKind};
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::time::Duration;
/// Retry Policy Configuration
///
/// Without specific use cases, users should generally rely on the default values set by `[RetryConfig::default]`(RetryConfig::default).`
///
/// Currently these fields are private and no setters provided. As needed, this configuration will become user-modifiable in the future..
#[derive(Clone)]
pub struct RetryConfig {
initial_retry_tokens: usize,
retry_cost: usize,
no_retry_increment: usize,
timeout_retry_cost: usize,
max_retries: u32,
max_backoff: Duration,
base: fn() -> f64,
}
impl RetryConfig {
/// Override `b` in the exponential backoff computation
///
/// By default, `base` is a randomly generated value between 0 and 1. In tests, it can
/// be helpful to override this:
/// ```rust
/// use aws_hyper::RetryConfig;
/// let conf = RetryConfig::default().with_base(||1_f64);
/// ```
pub fn with_base(mut self, base: fn() -> f64) -> Self {
self.base = base;
self
}
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
initial_retry_tokens: INITIAL_RETRY_TOKENS,
retry_cost: RETRY_COST,
no_retry_increment: 1,
timeout_retry_cost: 10,
max_retries: MAX_RETRIES,
max_backoff: Duration::from_secs(20),
// by default, use a random base for exponential backoff
base: fastrand::f64,
}
}
}
const MAX_RETRIES: u32 = 3;
const INITIAL_RETRY_TOKENS: usize = 500;
const RETRY_COST: usize = 5;
/// Manage retries for a service
///
/// An implementation of the `standard` AWS retry strategy as specified in the SEP. A `Strategy` is scoped to a client.
/// For an individual request, call [`RetryHandlerFactory::new_handler()`](RetryHandlerFactory::new_handler)
///
/// In the future, adding support for the adaptive retry strategy will be added by adding a `TokenBucket` to
/// `CrossRequestRetryState`
/// Its main functionality is via `new_handler` which creates a `RetryHandler` to manage the retry for
/// an individual request.
pub struct RetryHandlerFactory {
config: RetryConfig,
shared_state: CrossRequestRetryState,
}
impl RetryHandlerFactory {
pub fn new(config: RetryConfig) -> Self {
Self {
shared_state: CrossRequestRetryState::new(config.initial_retry_tokens),
config,
}
}
pub fn with_config(&mut self, config: RetryConfig) {
self.config = config;
}
pub(crate) fn new_handler(&self) -> RetryHandler {
RetryHandler {
local: RequestLocalRetryState::new(),
shared: self.shared_state.clone(),
config: self.config.clone(),
}
}
}
#[derive(Default, Clone)]
struct RequestLocalRetryState {
attempts: u32,
last_quota_usage: Option<usize>,
}
impl RequestLocalRetryState {
pub fn new() -> Self {
Self::default()
}
}
/* TODO in followup PR:
/// RetryPartition represents a scope for cross request retry state
///
/// For example, a retry partition could be the id of a service. This would give each service a separate retry budget.
struct RetryPartition(Cow<'static, str>); */
/// Shared state between multiple requests to the same client.
#[derive(Clone)]
struct CrossRequestRetryState {
quota_available: Arc<Mutex<usize>>,
}
// clippy is upset that we didn't use AtomicUsize here, but doing so makes the code
// significantly more complicated for negligible benefit.
#[allow(clippy::mutex_atomic)]
impl CrossRequestRetryState {
pub fn new(initial_quota: usize) -> Self {
Self {
quota_available:
Arc::new(Mutex::new(initial_quota)),
}
}
fn quota_release(&self, value: Option<usize>, config: &RetryConfig) {
let mut quota = self.quota_available.lock().unwrap();
*quota += value.unwrap_or(config.no_retry_increment);
}
/// Attempt to acquire retry quota for `ErrorKind`
///
/// If quota is available, the amount of quota consumed is returned
/// If no quota is available, `None` is returned.
fn quota_acquire(&self, err: &ErrorKind, config: &RetryConfig) -> Option<usize> {
let mut quota = self.quota_available.lock().unwrap();
let retry_cost = if err == &ErrorKind::TransientError {
config.timeout_retry_cost
} else {
config.retry_cost
};
if retry_cost > *quota {
None
} else {
*quota -= retry_cost;
Some(retry_cost)
}
}
}
/// RetryHandler
///
/// Implement retries for an individual request.
/// It is intended to be used as a [Tower Retry Policy](tower::retry::Policy) for use in tower-based
/// middleware stacks.
#[derive(Clone)]
pub(crate) struct RetryHandler {
local: RequestLocalRetryState,
shared: CrossRequestRetryState,
config: RetryConfig,
}
#[cfg(test)]
impl RetryHandler {
fn retry_quota(&self) -> usize {
*self.shared.quota_available.lock().unwrap()
}
}
impl RetryHandler {
/// Determine the correct response given `retry_kind`
///
/// If a retry is specified, this function returns `(next, backoff_duration)`
/// If no retry is specified, this function returns None
pub fn attempt_retry(&self, retry_kind: Result<(), ErrorKind>) -> Option<(Self, Duration)> {
let quota_used = match retry_kind {
Ok(_) => {
self.shared
.quota_release(self.local.last_quota_usage, &self.config);
return None;
}
Err(e) => {
if self.local.attempts == self.config.max_retries - 1 {
return None;
}
self.shared.quota_acquire(&e, &self.config)?
}
};
/*
From the retry spec:
b = random number within the range of: 0 <= b <= 1
r = 2
t_i = min(br^i, MAX_BACKOFF);
*/
let r: i32 = 2;
let b = (self.config.base)();
let backoff = b * (r.pow(self.local.attempts) as f64);
let backoff = Duration::from_secs_f64(backoff).min(self.config.max_backoff);
let next = RetryHandler {
local: RequestLocalRetryState {
attempts: self.local.attempts + 1,
last_quota_usage: Some(quota_used),
},
shared: self.shared.clone(),
config: self.config.clone(),
};
Some((next, backoff))
}
}
impl<Handler, R, T, E>
tower::retry::Policy<operation::Operation<Handler, R>, SdkSuccess<T>, SdkError<E>>
for RetryHandler
where
E: ProvideErrorKind,
Handler: Clone,
R: ClassifyResponse<SdkSuccess<T>, SdkError<E>>,
{
type Future = Pin<Box<dyn Future<Output = Self>>>;
fn retry(
&self,
req: &Operation<Handler, R>,
result: Result<&SdkSuccess<T>, &SdkError<E>>,
) -> Option<Self::Future> {
let policy = req.retry_policy();
let retry = policy.classify(result);
let (next, fut) = match retry {
RetryKind::Explicit(dur) => (self.clone(), dur),
RetryKind::NotRetryable => return None,
RetryKind::Error(err) => self.attempt_retry(Err(err))?,
_ => return None,
};
let fut = async move {
tokio::time::sleep(fut).await;
next
};
Some(Box::pin(fut))
}
fn clone_request(&self, req: &Operation<Handler, R>) -> Option<Operation<Handler, R>> {
req.try_clone()
}
}
#[cfg(test)]
mod test {
use crate::retry::{
RetryConfig, RetryHandlerFactory,
};
use smithy_types::retry::ErrorKind;
use std::time::Duration;
fn test_config() -> RetryConfig {
RetryConfig::default().with_base(|| 1_f64)
}
#[test]
fn eventual_success() {
let policy = RetryHandlerFactory::new(test_config()).new_handler();
let (policy, dur) = policy
.attempt_retry(Err(ErrorKind::ServerError))
.expect("should retry");
assert_eq!(dur, Duration::from_secs(1));
assert_eq!(policy.retry_quota(), 495);
let (policy, dur) = policy
.attempt_retry(Err(ErrorKind::ServerError))
.expect("should retry");
assert_eq!(dur, Duration::from_secs(2));
assert_eq!(policy.retry_quota(), 490);
let no_retry = policy.attempt_retry(Ok(()));
assert!(no_retry.is_none());
assert_eq!(policy.retry_quota(), 495);
}
#[test]
fn no_more_attempts() {
let policy = RetryHandlerFactory::new(test_config()).new_handler();
let (policy, dur) = policy
.attempt_retry(Err(ErrorKind::ServerError))
.expect("should retry");
assert_eq!(dur, Duration::from_secs(1));
assert_eq!(policy.retry_quota(), 495);
let (policy, dur) = policy
.attempt_retry(Err(ErrorKind::ServerError))
.expect("should retry");
assert_eq!(dur, Duration::from_secs(2));
assert_eq!(policy.retry_quota(), 490);
let no_retry = policy.attempt_retry(Err(ErrorKind::ServerError));
assert!(no_retry.is_none());
assert_eq!(policy.retry_quota(), 490);
}
#[test]
fn no_quota() {
let mut conf = test_config();
conf.initial_retry_tokens = 5;
let policy = RetryHandlerFactory::new(conf).new_handler();
let (policy, dur) = policy
.attempt_retry(Err(ErrorKind::ServerError))
.expect("should retry");
assert_eq!(dur, Duration::from_secs(1));
assert_eq!(policy.retry_quota(), 0);
let no_retry = policy.attempt_retry(Err(ErrorKind::ServerError));
assert!(no_retry.is_none());
assert_eq!(policy.retry_quota(), 0);
}
#[test]
fn backoff_timing() {
let mut conf = test_config();
conf.max_retries = 5;
let policy = RetryHandlerFactory::new(conf).new_handler();
let (policy, dur) = policy
.attempt_retry(Err(ErrorKind::ServerError))
.expect("should retry");
assert_eq!(dur, Duration::from_secs(1));
assert_eq!(policy.retry_quota(), 495);
let (policy, dur) = policy
.attempt_retry(Err(ErrorKind::ServerError))
.expect("should retry");
assert_eq!(dur, Duration::from_secs(2));
assert_eq!(policy.retry_quota(), 490);
let (policy, dur) = policy
.attempt_retry(Err(ErrorKind::ServerError))
.expect("should retry");
assert_eq!(dur, Duration::from_secs(4));
assert_eq!(policy.retry_quota(), 485);
let (policy, dur) = policy
.attempt_retry(Err(ErrorKind::ServerError))
.expect("should retry");
assert_eq!(dur, Duration::from_secs(8));
assert_eq!(policy.retry_quota(), 480);
let no_retry = policy.attempt_retry(Err(ErrorKind::ServerError));
assert!(no_retry.is_none());
assert_eq!(policy.retry_quota(), 480);
}
#[test]
fn max_backoff_time() {
let mut conf = test_config();
conf.max_retries = 5;
conf.max_backoff = Duration::from_secs(3);
let policy = RetryHandlerFactory::new(conf).new_handler();
let (policy, dur) = policy
.attempt_retry(Err(ErrorKind::ServerError))
.expect("should retry");
assert_eq!(dur, Duration::from_secs(1));
assert_eq!(policy.retry_quota(), 495);
let (policy, dur) = policy
.attempt_retry(Err(ErrorKind::ServerError))
.expect("should retry");
assert_eq!(dur, Duration::from_secs(2));
assert_eq!(policy.retry_quota(), 490);
let (policy, dur) = policy
.attempt_retry(Err(ErrorKind::ServerError))
.expect("should retry");
assert_eq!(dur, Duration::from_secs(3));
assert_eq!(policy.retry_quota(), 485);
let (policy, dur) = policy
.attempt_retry(Err(ErrorKind::ServerError))
.expect("should retry");
assert_eq!(dur, Duration::from_secs(3));
assert_eq!(policy.retry_quota(), 480);
let no_retry = policy.attempt_retry(Err(ErrorKind::ServerError));
assert!(no_retry.is_none());
assert_eq!(policy.retry_quota(), 480);
}
}

View File

@ -6,8 +6,9 @@
use aws_auth::Credentials;
use aws_endpoint::{set_endpoint_resolver, DefaultAwsEndpointResolver};
use aws_http::user_agent::AwsUserAgent;
use aws_http::AwsErrorRetryPolicy;
use aws_hyper::test_connection::{TestConnection, ValidateRequest};
use aws_hyper::Client;
use aws_hyper::{Client, RetryConfig, SdkError};
use aws_sig_auth::signer::OperationSigningConfig;
use aws_types::region::Region;
use bytes::Bytes;
@ -17,18 +18,21 @@ use smithy_http::body::SdkBody;
use smithy_http::operation;
use smithy_http::operation::Operation;
use smithy_http::response::ParseHttpResponse;
use smithy_types::retry::{ErrorKind, ProvideErrorKind};
use std::convert::Infallible;
use std::error::Error;
use std::fmt;
use std::fmt::{Display, Formatter};
use std::sync::Arc;
use std::time::{Duration, UNIX_EPOCH};
use tokio::time::Instant;
#[derive(Clone)]
struct TestOperationParser;
#[derive(Debug)]
struct OperationError;
impl Display for OperationError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "{:?}", self)
@ -37,6 +41,16 @@ impl Display for OperationError {
impl Error for OperationError {}
impl ProvideErrorKind for OperationError {
fn retryable_error_kind(&self) -> Option<ErrorKind> {
Some(ErrorKind::ThrottlingError)
}
fn code(&self) -> Option<&str> {
None
}
}
impl<B> ParseHttpResponse<B> for TestOperationParser
where
B: http_body::Body,
@ -44,7 +58,11 @@ where
type Output = Result<String, OperationError>;
fn parse_unloaded(&self, _response: &mut Response<B>) -> Option<Self::Output> {
Some(Ok("Hello!".to_string()))
if _response.status().is_success() {
Some(Ok("Hello!".to_string()))
} else {
Some(Err(OperationError))
}
}
fn parse_loaded(&self, _response: &Response<Bytes>) -> Self::Output {
@ -52,7 +70,7 @@ where
}
}
fn test_operation() -> Operation<TestOperationParser, ()> {
fn test_operation() -> Operation<TestOperationParser, AwsErrorRetryPolicy> {
let req = operation::Request::new(http::Request::new(SdkBody::from("request body")))
.augment(|req, mut conf| {
set_endpoint_resolver(
@ -70,7 +88,7 @@ fn test_operation() -> Operation<TestOperationParser, ()> {
Result::<_, Infallible>::Ok(req)
})
.unwrap();
Operation::new(req, TestOperationParser)
Operation::new(req, TestOperationParser).with_retry_policy(AwsErrorRetryPolicy::new())
}
#[tokio::test]
@ -102,3 +120,85 @@ async fn e2e_test() {
assert_eq!(actual.body().bytes(), expected.body().bytes());
assert_eq!(actual.uri(), expected.uri());
}
#[tokio::test]
async fn retry_test() {
fn req() -> http::Request<SdkBody> {
http::Request::builder()
.body(SdkBody::from("request body"))
.unwrap()
}
fn ok() -> http::Response<&'static str> {
http::Response::builder()
.status(200)
.body("response body")
.unwrap()
}
fn err() -> http::Response<&'static str> {
http::Response::builder()
.status(500)
.body("response body")
.unwrap()
}
// 1 failing response followed by 1 succesful response
let events = vec![
// First operation
(req(), err()),
(req(), err()),
(req(), ok()),
// Second operation
(req(), err()),
(req(), ok()),
// Third operation will fail, only errors
(req(), err()),
(req(), err()),
(req(), err()),
(req(), err()),
(req(), err()),
(req(), err()),
(req(), err()),
];
let conn = TestConnection::new(events);
let retry_config = RetryConfig::default().with_base(|| 1_f64);
let client = Client::new(conn.clone()).with_retry_config(retry_config);
tokio::time::pause();
let initial = tokio::time::Instant::now();
let resp = client
.call(test_operation())
.await
.expect("successful operation");
assert_time_passed(initial, Duration::from_secs(3));
assert_eq!(resp, "Hello!");
// 3 requests should have been made, 2 failing & one success
assert_eq!(conn.requests().len(), 3);
let initial = tokio::time::Instant::now();
client
.call(test_operation())
.await
.expect("successful operation");
assert_time_passed(initial, Duration::from_secs(1));
assert_eq!(conn.requests().len(), 5);
let initial = tokio::time::Instant::now();
let err = client
.call(test_operation())
.await
.expect_err("all responses failed");
// three more tries followed by failure
assert_eq!(conn.requests().len(), 8);
assert!(matches!(err, SdkError::ServiceError { .. }));
assert_time_passed(initial, Duration::from_secs(3));
}
/// Validate that time has passed with a 5ms tolerance
///
/// This is to account for some non-determinism in the Tokio timer
fn assert_time_passed(initial: Instant, passed: Duration) {
let now = tokio::time::Instant::now();
let delta = now - initial;
if (delta.as_millis() as i128 - passed.as_millis() as i128).abs() > 5 {
assert_eq!(delta, passed)
}
}

View File

@ -12,7 +12,8 @@ val DECORATORS = listOf(
RegionDecorator(),
AwsEndpointDecorator(),
UserAgentDecorator(),
SigV4SigningDecorator()
SigV4SigningDecorator(),
RetryPolicyDecorator()
)
class AwsCodegenDecorator : CombinedCodegenDecorator(DECORATORS) {

View File

@ -97,8 +97,7 @@ class EndpointResolverFeature(private val runtimeConfig: RuntimeConfig, private
OperationCustomization() {
override fun section(section: OperationSection): Writable {
return when (section) {
OperationSection.ImplBlock -> emptySection
is OperationSection.Feature -> writable {
is OperationSection.MutateRequest -> writable {
rust(
"""
#T::set_endpoint_resolver(&mut ${section.request}.config_mut(), ${section.config}.endpoint_resolver.clone());
@ -106,6 +105,7 @@ class EndpointResolverFeature(private val runtimeConfig: RuntimeConfig, private
runtimeConfig.awsEndpointDependency().asType()
)
}
else -> emptySection
}
}
}

View File

@ -90,7 +90,7 @@ class CredentialProviderConfig(runtimeConfig: RuntimeConfig) : ConfigCustomizati
class CredentialsProviderFeature(private val runtimeConfig: RuntimeConfig) : OperationCustomization() {
override fun section(section: OperationSection): Writable {
return when (section) {
is OperationSection.Feature -> writable {
is OperationSection.MutateRequest -> writable {
rust(
"""
#T(&mut ${section.request}.config_mut(), ${section.config}.credentials_provider.clone());

View File

@ -107,8 +107,7 @@ class RegionProviderConfig(runtimeConfig: RuntimeConfig) : ConfigCustomization()
class RegionConfigPlugin : OperationCustomization() {
override fun section(section: OperationSection): Writable {
return when (section) {
OperationSection.ImplBlock -> emptySection
is OperationSection.Feature -> writable {
is OperationSection.MutateRequest -> writable {
// Allow the region to be late-inserted via another method
rust(
"""
@ -118,6 +117,7 @@ class RegionConfigPlugin : OperationCustomization() {
"""
)
}
else -> emptySection
}
}
}

View File

@ -0,0 +1,43 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
package software.amazon.smithy.rustsdk
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.rust.codegen.rustlang.asType
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.writable
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator
import software.amazon.smithy.rust.codegen.smithy.generators.OperationCustomization
import software.amazon.smithy.rust.codegen.smithy.generators.OperationSection
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig
class RetryPolicyDecorator : RustCodegenDecorator {
override val name: String = "RetryPolicy"
override val order: Byte = 0
override fun operationCustomizations(
protocolConfig: ProtocolConfig,
operation: OperationShape,
baseCustomizations: List<OperationCustomization>
): List<OperationCustomization> {
return baseCustomizations + RetryPolicyFeature(protocolConfig.runtimeConfig)
}
}
class RetryPolicyFeature(private val runtimeConfig: RuntimeConfig) : OperationCustomization() {
override fun retryType(): RuntimeType = runtimeConfig.awsHttp().asType().copy(name = "AwsErrorRetryPolicy")
override fun section(section: OperationSection) = when (section) {
is OperationSection.FinalizeOperation -> writable {
rust(
"let ${section.operation} = ${section.operation}.with_retry_policy(#T::AwsErrorRetryPolicy::new());",
runtimeConfig.awsHttp().asType()
)
}
else -> emptySection
}
}

View File

@ -82,7 +82,7 @@ class SigV4SigningFeature(private val runtimeConfig: RuntimeConfig) :
OperationCustomization() {
override fun section(section: OperationSection): Writable {
return when (section) {
is OperationSection.Feature -> writable {
is OperationSection.MutateRequest -> writable {
// TODO: this needs to be customized for individual operations, not just `default_config()`
rustTemplate(
"""

View File

@ -63,7 +63,7 @@ fun RuntimeConfig.userAgentModule() = awsHttp().asType().copy(name = "user_agent
class UserAgentFeature(private val runtimeConfig: RuntimeConfig) : OperationCustomization() {
override fun section(section: OperationSection): Writable = when (section) {
is OperationSection.Feature -> writable {
is OperationSection.MutateRequest -> writable {
rust(
"""
${section.request}.config_mut().insert(#T::AwsUserAgent::new_from_environment(crate::API_METADATA.clone()));
@ -71,6 +71,6 @@ class UserAgentFeature(private val runtimeConfig: RuntimeConfig) : OperationCust
runtimeConfig.userAgentModule()
)
}
OperationSection.ImplBlock -> emptySection
else -> emptySection
}
}

View File

@ -60,7 +60,7 @@ internal class EndpointConfigCustomizationTest {
.endpoint(&Region::from("us-east-1")).expect("default resolver produces a valid endpoint");
let mut uri = Uri::from_static("/?k=v");
endpoint.set_endpoint(&mut uri, None);
assert_eq!(uri, Uri::from_static("https://us-east-1.differentprefix.amazonaws.com/?k=v"));
assert_eq!(uri, Uri::from_static("https://differentprefix.us-east-1.amazonaws.com/?k=v"));
"""
)
}

View File

@ -215,7 +215,7 @@ tasks.register<Exec>("cargoDocs") {
workingDir(sdkOutputDir)
// disallow warnings
environment("RUSTDOCFLAGS", "-D warnings")
commandLine("cargo", "doc", "--no-deps")
commandLine("cargo", "doc", "--no-deps", "--document-private-items")
dependsOn("assemble")
}

View File

@ -9,3 +9,6 @@ edition = "2018"
[dependencies]
kms = { path = "../../build/aws-sdk/kms" }
smithy-http = { path = "../../build/aws-sdk/smithy-http" }
smithy-types = { path = "../../build/aws-sdk/smithy-types" }
http = "0.2.3"

View File

@ -3,10 +3,39 @@
* SPDX-License-Identifier: Apache-2.0.
*/
use kms::output::GenerateRandomOutput;
use kms::error::{CreateAliasError, LimitExceededError};
use kms::operation::{CreateAlias};
use kms::output::{GenerateRandomOutput, CreateAliasOutput};
use kms::Blob;
use smithy_http::result::{SdkError, SdkSuccess};
use smithy_http::retry::ClassifyResponse;
use smithy_types::retry::{ErrorKind, RetryKind};
#[test]
fn validate_sensitive_trait() {
let output = GenerateRandomOutput::builder().plaintext(Blob::new("some output")).build();
assert_eq!(format!("{:?}", output), "GenerateRandomOutput { plaintext: \"*** Sensitive Data Redacted ***\" }");
let output = GenerateRandomOutput::builder()
.plaintext(Blob::new("some output"))
.build();
assert_eq!(
format!("{:?}", output),
"GenerateRandomOutput { plaintext: \"*** Sensitive Data Redacted ***\" }"
);
}
#[test]
fn errors_are_retryable() {
let err = CreateAliasError::LimitExceededError(LimitExceededError::builder().build());
assert_eq!(err.code(), Some("LimitExceededException"));
let conf = kms::Config::builder().build();
let op = CreateAlias::builder().build(&conf);
let err =
Result::<SdkSuccess<CreateAliasOutput, &str>, SdkError<CreateAliasError, &str>>::Err(
SdkError::ServiceError {
raw: http::Response::builder().body("resp").unwrap(),
err,
},
);
let retry_kind = op.retry_policy().classify(err.as_ref());
assert_eq!(retry_kind, RetryKind::Error(ErrorKind::ThrottlingError));
}

View File

@ -68,15 +68,19 @@ class OperationInputBuilderGenerator(
) : BuilderGenerator(model, symbolProvider, shape.inputShape(model)) {
override fun buildFn(implBlockWriter: RustWriter) {
val fallibleBuilder = StructureGenerator.fallibleBuilder(shape.inputShape(model), symbolProvider)
val retryType = "()"
val returnType = "#T<#{T}, $retryType>".letIf(fallibleBuilder) { "Result<$it, String>" }
val outputSymbol = symbolProvider.toSymbol(shape)
val operationT = RuntimeType.operation(symbolProvider.config().runtimeConfig)
val operationModule = RuntimeType.operationModule(symbolProvider.config().runtimeConfig)
val sdkBody = RuntimeType.sdkBody(symbolProvider.config().runtimeConfig)
val retryType = features.mapNotNull { it.retryType() }.firstOrNull()?.let { implBlockWriter.format(it) } ?: "()"
val returnType = with(implBlockWriter) {
"${format(operationT)}<${format(outputSymbol)}, $retryType>".letIf(fallibleBuilder) { "Result<$it, String>" }
}
implBlockWriter.docs("Consumes the builder and constructs an Operation<#D>", outputSymbol)
implBlockWriter.rustBlock("pub fn build(self, _config: &#T::Config) -> $returnType", RuntimeType.Config, operationT, outputSymbol) {
// For codegen simplicity, allow `let x = ...; x`
implBlockWriter.rust("##[allow(clippy::let_and_return)]")
implBlockWriter.rustBlock("pub fn build(self, _config: &#T::Config) -> $returnType", RuntimeType.Config) {
conditionalBlock("Ok({", "})", conditional = fallibleBuilder) {
withBlock("let op = #T::new(", ");", outputSymbol) {
coreBuilder(this)
@ -88,16 +92,18 @@ class OperationInputBuilderGenerator(
""",
operationModule, sdkBody
)
features.forEach { it.section(OperationSection.Feature("request", "_config"))(this) }
features.forEach { it.section(OperationSection.MutateRequest("request", "_config"))(this) }
rust(
"""
#1T::Operation::new(
let op = #1T::Operation::new(
request,
op
).with_metadata(#1T::Metadata::new(${shape.id.name.dq()}, ${serviceName.dq()}))
).with_metadata(#1T::Metadata::new(${shape.id.name.dq()}, ${serviceName.dq()}));
""",
operationModule
operationModule,
)
features.forEach { it.section(OperationSection.FinalizeOperation("op", "_config"))(this) }
rust("op")
}
}
}

View File

@ -11,6 +11,7 @@ import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.EndpointTrait
import software.amazon.smithy.rust.codegen.rustlang.Derives
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.documentShape
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
@ -58,6 +59,7 @@ abstract class HttpProtocolGenerator(
val builderGenerator = OperationInputBuilderGenerator(model, symbolProvider, operationShape, protocolConfig.moduleName, customizations)
builderGenerator.render(inputWriter)
// impl OperationInputShape { ... }
inputWriter.implBlock(inputShape, symbolProvider) {
toHttpRequestImpl(this, operationShape, inputShape)
val shapeId = inputShape.expectTrait(SyntheticInputTrait::class.java).body
@ -79,6 +81,7 @@ abstract class HttpProtocolGenerator(
}
val operationName = symbolProvider.toSymbol(operationShape).name
operationWriter.documentShape(operationShape, model)
Derives(setOf(RuntimeType.Clone)).render(operationWriter)
operationWriter.rustBlock("pub struct $operationName") {
write("input: #T", inputSymbol)
}

View File

@ -5,6 +5,7 @@
package software.amazon.smithy.rust.codegen.smithy.generators
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.customize.NamedSectionGenerator
import software.amazon.smithy.rust.codegen.smithy.customize.Section
@ -18,7 +19,11 @@ sealed class OperationSection(name: String) : Section(name) {
* [config]: Name of the variable holding the service config.
*
* */
data class Feature(val request: String, val config: String) : OperationSection("Feature")
data class MutateRequest(val request: String, val config: String) : OperationSection("Feature")
data class FinalizeOperation(val operation: String, val config: String) : OperationSection("Finalize")
}
typealias OperationCustomization = NamedSectionGenerator<OperationSection>
abstract class OperationCustomization : NamedSectionGenerator<OperationSection>() {
open fun retryType(): RuntimeType? = null
}

View File

@ -17,7 +17,7 @@ use std::pin::Pin;
use std::task::{Context, Poll};
use tower::{BoxError, Layer, Service};
use tracing::field::display;
use tracing::{field, info_span, debug_span, Instrument};
use tracing::{debug_span, field, info_span, Instrument};
/// `ParseResponseService` dispatches [`Operation`](smithy_http::operation::Operation)s and parses them.
///
@ -109,11 +109,11 @@ where
Err(e) => Err(e.into()),
Ok(resp) => {
// load_response contains reading the body as far as is required & parsing the response
let response_span = debug_span!(
"load_response",
);
load_response(resp, &handler).instrument(response_span).await
},
let response_span = debug_span!("load_response",);
load_response(resp, &handler)
.instrument(response_span)
.await
}
};
match &resp {
Ok(_) => inner_span.record("status", &"ok"),

View File

@ -9,9 +9,9 @@ pub mod endpoint;
pub mod label;
pub mod middleware;
pub mod operation;
pub mod retry;
mod pin_util;
pub mod property_bag;
pub mod query;
pub mod response;
pub mod result;
pub mod retry;

View File

@ -4,6 +4,7 @@ use std::borrow::Cow;
use std::cell::{Ref, RefCell, RefMut};
use std::rc::Rc;
#[derive(Clone)]
pub struct Metadata {
operation: Cow<'static, str>,
service: Cow<'static, str>,
@ -30,6 +31,7 @@ impl Metadata {
}
#[non_exhaustive]
#[derive(Clone)]
pub struct Parts<H, R> {
pub response_handler: H,
pub retry_policy: R,
@ -50,6 +52,33 @@ impl<H, R> Operation<H, R> {
self.parts.metadata = Some(metadata);
self
}
pub fn with_retry_policy<R2>(self, retry_policy: R2) -> Operation<H, R2> {
Operation {
request: self.request,
parts: Parts {
response_handler: self.parts.response_handler,
retry_policy,
metadata: self.parts.metadata,
},
}
}
pub fn retry_policy(&self) -> &R {
&self.parts.retry_policy
}
pub fn try_clone(&self) -> Option<Self>
where
H: Clone,
R: Clone,
{
let request = self.request.try_clone()?;
Some(Self {
request,
parts: self.parts.clone(),
})
}
}
impl<H> Operation<H, ()> {

View File

@ -7,10 +7,14 @@
//!
//! For protocol agnostic retries, see `smithy_types::Retry`.
use smithy_types::retry::{ProvideErrorKind, RetryKind};
use smithy_types::retry::RetryKind;
pub trait ClassifyResponse {
fn classify<E, B>(&self, e: E, response: &http::Response<B>) -> RetryKind
where
E: ProvideErrorKind;
pub trait ClassifyResponse<T, E>: Clone {
fn classify(&self, response: Result<&T, &E>) -> RetryKind;
}
impl<T, E> ClassifyResponse<T, E> for () {
fn classify(&self, _: Result<&T, &E>) -> RetryKind {
RetryKind::NotRetryable
}
}