mirror of https://github.com/smithy-lang/smithy-rs
Implement timeouts for LazyCachingCredentialsProvider (#595)
* Implement timeouts for LazyCachingCredentialsProvider * Rename refresh to reload * Update CHANGELOG * Fix clippy * CR feedback * Add note about panic on `LazyCachedCredentialsProvider` builder * Fix doc comment code reference
This commit is contained in:
parent
03ae7cc6a1
commit
70999dadce
|
@ -1,5 +1,6 @@
|
|||
## vNext (Month Day Year)
|
||||
**New This Week**
|
||||
- :tada: Add LazyCachingCredentialsProvider to aws-auth for use with expiring credentials, such as STS AssumeRole. Update STS example to use this new provider (#578, #595)
|
||||
- :bug: Correctly encode HTTP Checksums using base64 instead of hex. Fixes aws-sdk-rust#164. (#615)
|
||||
- Update SDK gradle build logic to use gradle properties (#620)
|
||||
- (When complete) Add profile file provider for region (#594, #xyz)
|
||||
|
|
|
@ -5,8 +5,14 @@ authors = ["AWS Rust SDK Team <aws-sdk-rust@amazon.com>", "Russell Cohen <rcoh@a
|
|||
license = "Apache-2.0"
|
||||
edition = "2018"
|
||||
|
||||
[features]
|
||||
rt-tokio = ["smithy-async/rt-tokio"]
|
||||
default = ["rt-tokio"]
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
[dependencies]
|
||||
pin-project = "1"
|
||||
smithy-async = { path = "../../../rust-runtime/smithy-async", default-features = false }
|
||||
smithy-http = { path = "../../../rust-runtime/smithy-http" }
|
||||
aws-types = { path = "../aws-types" }
|
||||
tokio = { version = "1", features = ["sync"] }
|
||||
|
@ -20,3 +26,4 @@ http = "0.2.3"
|
|||
test-env-log = { version = "0.2.7", features = ["trace"] }
|
||||
tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread", "test-util"] }
|
||||
tracing-subscriber = { version = "0.2.16", features = ["fmt"] }
|
||||
smithy-async = { path = "../../../rust-runtime/smithy-async", features = ["rt-tokio"] }
|
||||
|
|
|
@ -3,6 +3,16 @@
|
|||
* SPDX-License-Identifier: Apache-2.0.
|
||||
*/
|
||||
|
||||
//! AWS credential providers, generic caching provider implementations, and traits to implement custom providers.
|
||||
//!
|
||||
//! Credentials providers acquire AWS credentials from environment variables, files,
|
||||
//! or calls to AWS services such as STS. Custom credential provider implementations can
|
||||
//! be provided by implementing [`ProvideCredentials`] for synchronous use-cases, or
|
||||
//! [`AsyncProvideCredentials`] for async use-cases. Generic credential caching implementations,
|
||||
//! for example,
|
||||
//! [`LazyCachingCredentialsProvider`](crate::provider::lazy_caching::LazyCachingCredentialsProvider),
|
||||
//! are also provided as part of this module.
|
||||
|
||||
mod cache;
|
||||
pub mod env;
|
||||
pub mod lazy_caching;
|
||||
|
@ -16,11 +26,13 @@ use std::fmt::{Debug, Display, Formatter};
|
|||
use std::future::{self, Future};
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Debug)]
|
||||
#[non_exhaustive]
|
||||
pub enum CredentialsError {
|
||||
CredentialsNotLoaded,
|
||||
ProviderTimedOut(Duration),
|
||||
Unhandled(Box<dyn Error + Send + Sync + 'static>),
|
||||
}
|
||||
|
||||
|
@ -28,6 +40,11 @@ impl Display for CredentialsError {
|
|||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
CredentialsError::CredentialsNotLoaded => write!(f, "CredentialsNotLoaded"),
|
||||
CredentialsError::ProviderTimedOut(d) => write!(
|
||||
f,
|
||||
"Credentials provider timed out after {} seconds",
|
||||
d.as_secs()
|
||||
),
|
||||
CredentialsError::Unhandled(err) => write!(f, "{}", err),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,8 @@
|
|||
* SPDX-License-Identifier: Apache-2.0.
|
||||
*/
|
||||
|
||||
//! Credential provider implementation that pulls from environment variables
|
||||
|
||||
use crate::provider::{CredentialsError, ProvideCredentials};
|
||||
use crate::Credentials;
|
||||
use aws_types::os_shim_internal::Env;
|
||||
|
|
|
@ -3,53 +3,51 @@
|
|||
* SPDX-License-Identifier: Apache-2.0.
|
||||
*/
|
||||
|
||||
//! Lazy, caching, credentials provider implementation
|
||||
|
||||
use crate::provider::cache::Cache;
|
||||
use crate::provider::time::TimeSource;
|
||||
use crate::provider::{AsyncProvideCredentials, BoxFuture, CredentialsResult};
|
||||
use crate::provider::{AsyncProvideCredentials, BoxFuture, CredentialsError, CredentialsResult};
|
||||
use smithy_async::future::timeout::Timeout;
|
||||
use smithy_async::rt::sleep::AsyncSleep;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tracing::{trace_span, Instrument};
|
||||
|
||||
const DEFAULT_REFRESH_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
const DEFAULT_LOAD_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
const DEFAULT_CREDENTIAL_EXPIRATION: Duration = Duration::from_secs(15 * 60);
|
||||
const DEFAULT_BUFFER_TIME: Duration = Duration::from_secs(10);
|
||||
|
||||
// TODO: Implement async runtime-agnostic timeouts
|
||||
// TODO: Add catch_unwind() to handle panics
|
||||
// TODO: Update doc comment below once catch_unwind and timeouts are implemented
|
||||
// TODO: Update warning not to use this in the STS example once it's prod ready
|
||||
|
||||
/// `LazyCachingCredentialsProvider` implements [`AsyncProvideCredentials`] by caching
|
||||
/// credentials that it loads by calling a user-provided [`AsyncProvideCredentials`] implementation.
|
||||
///
|
||||
/// For example, you can provide an [`AsyncProvideCredentials`] implementation that calls
|
||||
/// AWS STS's AssumeRole operation to get temporary credentials, and `LazyCachingCredentialsProvider`
|
||||
/// will cache those credentials until they expire.
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// This is __NOT__ production ready yet. Timeouts and panic safety have not been implemented yet.
|
||||
pub struct LazyCachingCredentialsProvider {
|
||||
time: Box<dyn TimeSource>,
|
||||
sleeper: Box<dyn AsyncSleep>,
|
||||
cache: Cache,
|
||||
refresh: Arc<dyn AsyncProvideCredentials>,
|
||||
_refresh_timeout: Duration,
|
||||
loader: Arc<dyn AsyncProvideCredentials>,
|
||||
load_timeout: Duration,
|
||||
default_credential_expiration: Duration,
|
||||
}
|
||||
|
||||
impl LazyCachingCredentialsProvider {
|
||||
fn new(
|
||||
time: impl TimeSource,
|
||||
refresh: Arc<dyn AsyncProvideCredentials>,
|
||||
refresh_timeout: Duration,
|
||||
sleeper: Box<dyn AsyncSleep>,
|
||||
loader: Arc<dyn AsyncProvideCredentials>,
|
||||
load_timeout: Duration,
|
||||
default_credential_expiration: Duration,
|
||||
buffer_time: Duration,
|
||||
) -> Self {
|
||||
LazyCachingCredentialsProvider {
|
||||
time: Box::new(time),
|
||||
sleeper,
|
||||
cache: Cache::new(buffer_time),
|
||||
refresh,
|
||||
_refresh_timeout: refresh_timeout,
|
||||
loader,
|
||||
load_timeout,
|
||||
default_credential_expiration,
|
||||
}
|
||||
}
|
||||
|
@ -66,7 +64,9 @@ impl AsyncProvideCredentials for LazyCachingCredentialsProvider {
|
|||
Self: 'a,
|
||||
{
|
||||
let now = self.time.now();
|
||||
let refresh = self.refresh.clone();
|
||||
let loader = self.loader.clone();
|
||||
let timeout_future = self.sleeper.sleep(self.load_timeout);
|
||||
let load_timeout = self.load_timeout;
|
||||
let cache = self.cache.clone();
|
||||
let default_credential_expiration = self.default_credential_expiration;
|
||||
|
||||
|
@ -75,16 +75,18 @@ impl AsyncProvideCredentials for LazyCachingCredentialsProvider {
|
|||
if let Some(credentials) = cache.yield_or_clear_if_expired(now).await {
|
||||
Ok(credentials)
|
||||
} else {
|
||||
// If we didn't get credentials from the cache, then we need to try and refresh.
|
||||
// There may be other threads also refreshing simultaneously, but this is OK
|
||||
// If we didn't get credentials from the cache, then we need to try and load.
|
||||
// There may be other threads also loading simultaneously, but this is OK
|
||||
// since the futures are not eagerly executed, and the cache will only run one
|
||||
// of them.
|
||||
let span = trace_span!("lazy_refresh_credentials");
|
||||
let future = refresh.provide_credentials();
|
||||
let span = trace_span!("lazy_load_credentials");
|
||||
let future = Timeout::new(loader.provide_credentials(), timeout_future);
|
||||
cache
|
||||
.get_or_load(|| {
|
||||
async move {
|
||||
let mut credentials = future.await?;
|
||||
let mut credentials = future
|
||||
.await
|
||||
.map_err(|_| CredentialsError::ProviderTimedOut(load_timeout))??;
|
||||
// If the credentials don't have an expiration time, then create a default one
|
||||
if credentials.expiry().is_none() {
|
||||
*credentials.expiry_mut() =
|
||||
|
@ -92,7 +94,7 @@ impl AsyncProvideCredentials for LazyCachingCredentialsProvider {
|
|||
}
|
||||
Ok(credentials)
|
||||
}
|
||||
// Only instrument the the actual refreshing future so that no span
|
||||
// Only instrument the the actual load future so that no span
|
||||
// is opened if the cache decides not to execute it.
|
||||
.instrument(span)
|
||||
})
|
||||
|
@ -105,10 +107,11 @@ impl AsyncProvideCredentials for LazyCachingCredentialsProvider {
|
|||
pub mod builder {
|
||||
use crate::provider::lazy_caching::{
|
||||
LazyCachingCredentialsProvider, DEFAULT_BUFFER_TIME, DEFAULT_CREDENTIAL_EXPIRATION,
|
||||
DEFAULT_REFRESH_TIMEOUT,
|
||||
DEFAULT_LOAD_TIMEOUT,
|
||||
};
|
||||
use crate::provider::time::SystemTimeSource;
|
||||
use crate::provider::AsyncProvideCredentials;
|
||||
use smithy_async::rt::sleep::{default_async_sleep, AsyncSleep};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
|
@ -124,7 +127,7 @@ pub mod builder {
|
|||
/// use std::time::Duration;
|
||||
///
|
||||
/// let provider = LazyCachingCredentialsProvider::builder()
|
||||
/// .refresh(async_provide_credentials_fn(|| async {
|
||||
/// .load(async_provide_credentials_fn(|| async {
|
||||
/// // An async process to retrieve credentials would go here:
|
||||
/// Ok(Credentials::from_keys("example", "example", None))
|
||||
/// }))
|
||||
|
@ -132,8 +135,9 @@ pub mod builder {
|
|||
/// ```
|
||||
#[derive(Default)]
|
||||
pub struct Builder {
|
||||
refresh: Option<Arc<dyn AsyncProvideCredentials>>,
|
||||
refresh_timeout: Option<Duration>,
|
||||
sleep: Option<Box<dyn AsyncSleep>>,
|
||||
load: Option<Arc<dyn AsyncProvideCredentials>>,
|
||||
load_timeout: Option<Duration>,
|
||||
buffer_time: Option<Duration>,
|
||||
default_credential_expiration: Option<Duration>,
|
||||
}
|
||||
|
@ -143,18 +147,27 @@ pub mod builder {
|
|||
Default::default()
|
||||
}
|
||||
|
||||
/// An implementation of [`AsyncProvideCredentials`] that will be used to refresh
|
||||
/// An implementation of [`AsyncProvideCredentials`] that will be used to load
|
||||
/// the cached credentials once they're expired.
|
||||
pub fn refresh(mut self, refresh: impl AsyncProvideCredentials + 'static) -> Self {
|
||||
self.refresh = Some(Arc::new(refresh));
|
||||
pub fn load(mut self, loader: impl AsyncProvideCredentials + 'static) -> Self {
|
||||
self.load = Some(Arc::new(loader));
|
||||
self
|
||||
}
|
||||
|
||||
/// Implementation of [`AsyncSleep`] to use for timeouts. This enables use of
|
||||
/// the `LazyCachingCredentialsProvider` with other async runtimes.
|
||||
/// If using Tokio as the async runtime, this should be set to an instance of
|
||||
/// [`TokioSleep`](smithy_async::rt::sleep::TokioSleep).
|
||||
pub fn sleep(mut self, sleep: impl AsyncSleep + 'static) -> Self {
|
||||
self.sleep = Some(Box::new(sleep));
|
||||
self
|
||||
}
|
||||
|
||||
/// (Optional) Timeout for the given [`AsyncProvideCredentials`] implementation.
|
||||
/// Defaults to 5 seconds.
|
||||
pub fn refresh_timeout(mut self, timeout: Duration) -> Self {
|
||||
self.refresh_timeout = Some(timeout);
|
||||
unimplemented!("refresh_timeout hasn't been implemented yet")
|
||||
pub fn load_timeout(mut self, timeout: Duration) -> Self {
|
||||
self.load_timeout = Some(timeout);
|
||||
self
|
||||
}
|
||||
|
||||
/// (Optional) Amount of time before the actual credential expiration time
|
||||
|
@ -176,6 +189,11 @@ pub mod builder {
|
|||
}
|
||||
|
||||
/// Creates the [`LazyCachingCredentialsProvider`].
|
||||
///
|
||||
/// ## Note:
|
||||
/// This will panic if no `sleep` implementation is given and if no default crate features
|
||||
/// are used. By default, the [`TokioSleep`](smithy_async::rt::sleep::TokioSleep)
|
||||
/// implementation will be set automatically.
|
||||
pub fn build(self) -> LazyCachingCredentialsProvider {
|
||||
let default_credential_expiration = self
|
||||
.default_credential_expiration
|
||||
|
@ -186,8 +204,11 @@ pub mod builder {
|
|||
);
|
||||
LazyCachingCredentialsProvider::new(
|
||||
SystemTimeSource,
|
||||
self.refresh.expect("refresh provider is required"),
|
||||
self.refresh_timeout.unwrap_or(DEFAULT_REFRESH_TIMEOUT),
|
||||
self.sleep.unwrap_or_else(|| {
|
||||
default_async_sleep().expect("no default sleep implementation available")
|
||||
}),
|
||||
self.load.expect("load implementation is required"),
|
||||
self.load_timeout.unwrap_or(DEFAULT_LOAD_TIMEOUT),
|
||||
self.buffer_time.unwrap_or(DEFAULT_BUFFER_TIME),
|
||||
default_credential_expiration,
|
||||
)
|
||||
|
@ -199,12 +220,13 @@ pub mod builder {
|
|||
mod tests {
|
||||
use crate::provider::lazy_caching::{
|
||||
LazyCachingCredentialsProvider, TimeSource, DEFAULT_BUFFER_TIME,
|
||||
DEFAULT_CREDENTIAL_EXPIRATION, DEFAULT_REFRESH_TIMEOUT,
|
||||
DEFAULT_CREDENTIAL_EXPIRATION, DEFAULT_LOAD_TIMEOUT,
|
||||
};
|
||||
use crate::provider::{
|
||||
async_provide_credentials_fn, AsyncProvideCredentials, CredentialsError, CredentialsResult,
|
||||
};
|
||||
use crate::Credentials;
|
||||
use smithy_async::rt::sleep::TokioSleep;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, SystemTime};
|
||||
use tracing::info;
|
||||
|
@ -234,20 +256,21 @@ mod tests {
|
|||
|
||||
fn test_provider<T: TimeSource>(
|
||||
time: T,
|
||||
refresh_list: Vec<CredentialsResult>,
|
||||
load_list: Vec<CredentialsResult>,
|
||||
) -> LazyCachingCredentialsProvider {
|
||||
let refresh_list = Arc::new(Mutex::new(refresh_list));
|
||||
let load_list = Arc::new(Mutex::new(load_list));
|
||||
LazyCachingCredentialsProvider::new(
|
||||
time,
|
||||
Box::new(TokioSleep::new()),
|
||||
Arc::new(async_provide_credentials_fn(move || {
|
||||
let list = refresh_list.clone();
|
||||
let list = load_list.clone();
|
||||
async move {
|
||||
let next = list.lock().unwrap().remove(0);
|
||||
info!("refreshing the credentials to {:?}", next);
|
||||
next
|
||||
}
|
||||
})),
|
||||
DEFAULT_REFRESH_TIMEOUT,
|
||||
DEFAULT_LOAD_TIMEOUT,
|
||||
DEFAULT_CREDENTIAL_EXPIRATION,
|
||||
DEFAULT_BUFFER_TIME,
|
||||
)
|
||||
|
@ -272,14 +295,15 @@ mod tests {
|
|||
#[test_env_log::test(tokio::test)]
|
||||
async fn initial_populate_credentials() {
|
||||
let time = TestTime::new(epoch_secs(100));
|
||||
let refresh = Arc::new(async_provide_credentials_fn(|| async {
|
||||
let loader = Arc::new(async_provide_credentials_fn(|| async {
|
||||
info!("refreshing the credentials");
|
||||
Ok(credentials(1000))
|
||||
}));
|
||||
let provider = LazyCachingCredentialsProvider::new(
|
||||
time,
|
||||
refresh,
|
||||
DEFAULT_REFRESH_TIMEOUT,
|
||||
Box::new(TokioSleep::new()),
|
||||
loader,
|
||||
DEFAULT_LOAD_TIMEOUT,
|
||||
DEFAULT_CREDENTIAL_EXPIRATION,
|
||||
DEFAULT_BUFFER_TIME,
|
||||
);
|
||||
|
@ -295,7 +319,7 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test_env_log::test(tokio::test)]
|
||||
async fn refresh_expired_credentials() {
|
||||
async fn reload_expired_credentials() {
|
||||
let time = TestTime::new(epoch_secs(100));
|
||||
let time_inner = time.time.clone();
|
||||
let provider = test_provider(
|
||||
|
@ -318,7 +342,7 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test_env_log::test(tokio::test)]
|
||||
async fn refresh_failed_error() {
|
||||
async fn load_failed_error() {
|
||||
let time = TestTime::new(epoch_secs(100));
|
||||
let time_inner = time.time.clone();
|
||||
let provider = test_provider(
|
||||
|
@ -335,8 +359,9 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test_env_log::test]
|
||||
fn refresh_retrieve_contention() {
|
||||
fn load_contention() {
|
||||
let rt = tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_time()
|
||||
.worker_threads(16)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
@ -377,4 +402,25 @@ mod tests {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test_env_log::test(tokio::test)]
|
||||
async fn load_timeout() {
|
||||
let time = TestTime::new(epoch_secs(100));
|
||||
let provider = LazyCachingCredentialsProvider::new(
|
||||
time,
|
||||
Box::new(TokioSleep::new()),
|
||||
Arc::new(async_provide_credentials_fn(|| async {
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
Ok(credentials(1000))
|
||||
})),
|
||||
Duration::from_millis(5),
|
||||
DEFAULT_CREDENTIAL_EXPIRATION,
|
||||
DEFAULT_BUFFER_TIME,
|
||||
);
|
||||
|
||||
assert!(matches!(
|
||||
provider.provide_credentials().await,
|
||||
Err(CredentialsError::ProviderTimedOut(_))
|
||||
));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ val smithyVersion: String by project
|
|||
|
||||
val sdkOutputDir = buildDir.resolve("aws-sdk")
|
||||
val runtimeModules = listOf(
|
||||
"smithy-async",
|
||||
"smithy-types",
|
||||
"smithy-json",
|
||||
"smithy-query",
|
||||
|
|
|
@ -14,10 +14,11 @@ async fn main() -> Result<(), dynamodb::Error> {
|
|||
tracing_subscriber::fmt::init();
|
||||
let client = sts::Client::from_env();
|
||||
|
||||
// NOTE: Do not use LazyCachingCredentialsProvider in production yet!
|
||||
// It hasn't implemented timeout or panic safety yet.
|
||||
// `LazyCachingCredentialsProvider` will load credentials if it doesn't have any non-expired
|
||||
// credentials cached. See the docs on the builder for the various configuration options,
|
||||
// such as timeouts, default expiration times, and more.
|
||||
let sts_provider = LazyCachingCredentialsProvider::builder()
|
||||
.refresh(async_provide_credentials_fn(move || {
|
||||
.load(async_provide_credentials_fn(move || {
|
||||
let client = client.clone();
|
||||
async move {
|
||||
let session_token = client
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
[package]
|
||||
name = "smithy-async"
|
||||
version = "0.1.0"
|
||||
authors = ["AWS Rust SDK Team <aws-sdk-rust@amazon.com>", "John DiSanti <jdisanti@amazon.com>"]
|
||||
edition = "2018"
|
||||
|
||||
[features]
|
||||
rt-tokio = ["tokio"]
|
||||
default = ["rt-tokio"]
|
||||
|
||||
[dependencies]
|
||||
pin-project-lite = "0.2"
|
||||
tokio = { version = "1.6", features = ["time"], optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { version = "1.6", features = ["rt", "macros"] }
|
|
@ -0,0 +1,9 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0.
|
||||
*/
|
||||
|
||||
//! Useful runtime-agnostic future implementations.
|
||||
|
||||
pub mod never;
|
||||
pub mod timeout;
|
|
@ -0,0 +1,29 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0.
|
||||
*/
|
||||
|
||||
//! Provides the [`Never`] future that never completes.
|
||||
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
/// Future that never completes.
|
||||
#[non_exhaustive]
|
||||
#[derive(Default)]
|
||||
pub struct Never;
|
||||
|
||||
impl Never {
|
||||
pub fn new() -> Never {
|
||||
Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
impl Future for Never {
|
||||
type Output = ();
|
||||
|
||||
fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
|
@ -0,0 +1,128 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0.
|
||||
*/
|
||||
|
||||
// This code was copied and then modified from Tokio.
|
||||
|
||||
/*
|
||||
* Copyright (c) 2021 Tokio Contributors
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any
|
||||
* person obtaining a copy of this software and associated
|
||||
* documentation files (the "Software"), to deal in the
|
||||
* Software without restriction, including without
|
||||
* limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of
|
||||
* the Software, and to permit persons to whom the Software
|
||||
* is furnished to do so, subject to the following
|
||||
* conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice
|
||||
* shall be included in all copies or substantial portions
|
||||
* of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
|
||||
* ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
|
||||
* TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
|
||||
* PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
|
||||
* SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||
* OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
|
||||
* IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
* DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
//! Provides the [`Timeout`] future for adding a timeout to another future.
|
||||
|
||||
use pin_project_lite::pin_project;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
|
||||
pub struct TimedOutError;
|
||||
|
||||
impl Error for TimedOutError {}
|
||||
|
||||
impl fmt::Display for TimedOutError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "timed out")
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
#[non_exhaustive]
|
||||
#[must_use = "futures do nothing unless you `.await` or poll them"]
|
||||
#[derive(Debug)]
|
||||
pub struct Timeout<T, S> {
|
||||
#[pin]
|
||||
value: T,
|
||||
#[pin]
|
||||
sleep: S,
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S> Timeout<T, S> {
|
||||
pub fn new(value: T, sleep: S) -> Timeout<T, S> {
|
||||
Timeout { value, sleep }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S> Future for Timeout<T, S>
|
||||
where
|
||||
T: Future,
|
||||
S: Future,
|
||||
{
|
||||
type Output = Result<T::Output, TimedOutError>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let me = self.project();
|
||||
|
||||
// First, try polling the future
|
||||
if let Poll::Ready(v) = me.value.poll(cx) {
|
||||
return Poll::Ready(Ok(v));
|
||||
}
|
||||
|
||||
// Now check the timer
|
||||
match me.sleep.poll(cx) {
|
||||
Poll::Ready(_) => Poll::Ready(Err(TimedOutError)),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{TimedOutError, Timeout};
|
||||
use crate::future::never::Never;
|
||||
|
||||
#[tokio::test]
|
||||
async fn success() {
|
||||
assert_eq!(
|
||||
Ok(Ok(5)),
|
||||
Timeout::new(async { Ok::<isize, isize>(5) }, Never).await
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn failure() {
|
||||
assert_eq!(
|
||||
Ok(Err(0)),
|
||||
Timeout::new(async { Err::<isize, isize>(0) }, Never).await
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn timeout() {
|
||||
assert_eq!(Err(TimedOutError), Timeout::new(Never, async {}).await);
|
||||
}
|
||||
|
||||
// If the value is available at the same time as the timeout, then return the value
|
||||
#[tokio::test]
|
||||
async fn prefer_value_to_timeout() {
|
||||
assert_eq!(Ok(5), Timeout::new(async { 5 }, async {}).await);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,12 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0.
|
||||
*/
|
||||
|
||||
//! Future utilities and runtime-agnostic abstractions for smithy-rs.
|
||||
//!
|
||||
//! Async runtime specific code is abstracted behind async traits, and implementations are
|
||||
//! provided via feature flag. For now, only Tokio runtime implementations are provided.
|
||||
|
||||
pub mod future;
|
||||
pub mod rt;
|
|
@ -0,0 +1,8 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0.
|
||||
*/
|
||||
|
||||
//! Async runtime agnostic traits and implementations.
|
||||
|
||||
pub mod sleep;
|
|
@ -0,0 +1,72 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0.
|
||||
*/
|
||||
|
||||
//! Provides an [`AsyncSleep`] trait that returns a future that sleeps for a given duration,
|
||||
//! and implementations of `AsyncSleep` for different async runtimes.
|
||||
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
|
||||
/// Async trait with a `sleep` function.
|
||||
pub trait AsyncSleep: std::fmt::Debug + Send + Sync {
|
||||
/// Returns a future that sleeps for the given `duration` of time.
|
||||
fn sleep(&self, duration: Duration) -> Sleep;
|
||||
}
|
||||
|
||||
/// Returns a default sleep implementation based on the features enabled, or `None` if
|
||||
/// there isn't one available from this crate.
|
||||
pub fn default_async_sleep() -> Option<Box<dyn AsyncSleep>> {
|
||||
sleep_tokio()
|
||||
}
|
||||
|
||||
/// Future returned by [`AsyncSleep`].
|
||||
#[non_exhaustive]
|
||||
pub struct Sleep(Pin<Box<dyn Future<Output = ()> + Send + 'static>>);
|
||||
|
||||
impl Sleep {
|
||||
fn new(future: impl Future<Output = ()> + Send + 'static) -> Sleep {
|
||||
Sleep(Box::pin(future))
|
||||
}
|
||||
}
|
||||
|
||||
impl Future for Sleep {
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.0.as_mut().poll(cx)
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation of [`AsyncSleep`] for Tokio.
|
||||
#[non_exhaustive]
|
||||
#[cfg(feature = "rt-tokio")]
|
||||
#[derive(Debug, Default)]
|
||||
pub struct TokioSleep;
|
||||
|
||||
#[cfg(feature = "rt-tokio")]
|
||||
impl TokioSleep {
|
||||
pub fn new() -> TokioSleep {
|
||||
Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "rt-tokio")]
|
||||
impl AsyncSleep for TokioSleep {
|
||||
fn sleep(&self, duration: Duration) -> Sleep {
|
||||
Sleep::new(tokio::time::sleep(duration))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "rt-tokio")]
|
||||
fn sleep_tokio() -> Option<Box<dyn AsyncSleep>> {
|
||||
Some(Box::new(TokioSleep::new()))
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "rt-tokio"))]
|
||||
fn sleep_tokio() -> Option<Box<dyn AsyncSleep>> {
|
||||
None
|
||||
}
|
Loading…
Reference in New Issue