Add convenience to async provide credentials from a closure (#577)

This commit is contained in:
John DiSanti 2021-07-02 15:05:53 -07:00 committed by GitHub
parent 081387bda1
commit ba0d182e87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 72 additions and 75 deletions

View File

@ -44,17 +44,61 @@ type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
/// An asynchronous credentials provider
///
/// If your use-case is synchronous, you should implement [ProvideCredentials] instead.
/// If your use-case is synchronous, you should implement [`ProvideCredentials`] instead. Otherwise,
/// consider using [`async_provide_credentials_fn`] with a closure rather than directly implementing
/// this trait.
pub trait AsyncProvideCredentials: Send + Sync {
fn provide_credentials(&self) -> BoxFuture<CredentialsResult>;
}
pub type CredentialsProvider = Arc<dyn AsyncProvideCredentials>;
/// A [`AsyncProvideCredentials`] implemented by a closure.
///
/// See [`async_provide_credentials_fn`] for more details.
#[derive(Copy, Clone)]
pub struct AsyncProvideCredentialsFn<T: Send + Sync> {
f: T,
}
impl<T, F> AsyncProvideCredentials for AsyncProvideCredentialsFn<T>
where
T: Fn() -> F + Send + Sync,
F: Future<Output = CredentialsResult> + Send + 'static,
{
fn provide_credentials(&self) -> BoxFuture<CredentialsResult> {
Box::pin((self.f)())
}
}
/// Returns a new [`AsyncProvideCredentialsFn`] with the given closure. This allows you
/// to create an [`AsyncProvideCredentials`] implementation from an async block that returns
/// a [`CredentialsResult`].
///
/// # Example
///
/// ```
/// use aws_auth::Credentials;
/// use aws_auth::provider::async_provide_credentials_fn;
///
/// async_provide_credentials_fn(|| async {
/// // Async process to retrieve credentials goes here
/// let credentials: Credentials = todo!().await?;
/// Ok(credentials)
/// });
/// ```
pub fn async_provide_credentials_fn<T, F>(f: T) -> AsyncProvideCredentialsFn<T>
where
T: Fn() -> F + Send + Sync,
F: Future<Output = CredentialsResult> + Send + 'static,
{
AsyncProvideCredentialsFn { f }
}
/// A synchronous credentials provider
///
/// This is offered as a convenience for credential provider implementations that don't
/// need to be async. Otherwise, implement [AsyncProvideCredentials].
/// need to be async. Otherwise, implement [`AsyncProvideCredentials`].
pub trait ProvideCredentials: Send + Sync {
fn provide_credentials(&self) -> Result<Credentials, CredentialsError>;
}

View File

@ -3,9 +3,7 @@
* SPDX-License-Identifier: Apache-2.0.
*/
use aws_auth::provider::{CredentialsError, ProvideCredentials};
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime};
use aws_auth::provider::{async_provide_credentials_fn, CredentialsError};
use sts::Credentials;
/// Implements a basic version of ProvideCredentials with AWS STS
@ -14,80 +12,35 @@ use sts::Credentials;
async fn main() -> Result<(), dynamodb::Error> {
tracing_subscriber::fmt::init();
let client = sts::Client::from_env();
let sts_provider = StsCredentialsProvider {
client,
credentials: Arc::new(Mutex::new(None)),
};
sts_provider.spawn_refresh_loop().await;
// NOTE: Do not use this in production! This will grab new credentials for every request.
// A high quality caching credential provider implementation is in the roadmap.
let dynamodb_conf = dynamodb::Config::builder()
.credentials_provider(sts_provider)
.credentials_provider(async_provide_credentials_fn(move || {
let client = client.clone();
async move {
let session_token = client
.get_session_token()
.send()
.await
.map_err(|err| CredentialsError::Unhandled(Box::new(err)))?;
let sts_credentials = session_token
.credentials
.expect("should include credentials");
Ok(Credentials::new(
sts_credentials.access_key_id.unwrap(),
sts_credentials.secret_access_key.unwrap(),
sts_credentials.session_token,
sts_credentials
.expiration
.map(|expiry| expiry.to_system_time().expect("sts sent a time < 0")),
"Sts",
))
}
}))
.build();
let client = dynamodb::Client::from_conf(dynamodb_conf);
println!("tables: {:?}", client.list_tables().send().await?);
Ok(())
}
/// This is a rough example of how you could implement ProvideCredentials with Amazon STS.
///
/// Do not use this in production! A high quality implementation is in the roadmap.
#[derive(Clone)]
struct StsCredentialsProvider {
client: sts::Client,
credentials: Arc<Mutex<Option<Credentials>>>,
}
impl ProvideCredentials for StsCredentialsProvider {
fn provide_credentials(&self) -> Result<Credentials, CredentialsError> {
let inner = self.credentials.lock().unwrap().clone();
inner.ok_or(CredentialsError::CredentialsNotLoaded)
}
}
impl StsCredentialsProvider {
pub async fn spawn_refresh_loop(&self) {
let _ = self
.refresh()
.await
.map_err(|e| eprintln!("failed to load credentials! {}", e));
let this = self.clone();
tokio::spawn(async move {
loop {
let needs_refresh = {
let creds = this.credentials.lock().unwrap();
let expiry = creds.as_ref().and_then(|creds| creds.expiry());
if creds.is_none() {
true
} else {
expiry
.map(|expiry| SystemTime::now() > expiry)
.unwrap_or(false)
}
};
if needs_refresh {
let _ = this
.refresh()
.await
.map_err(|e| eprintln!("failed to load credentials! {}", e));
}
tokio::time::sleep(Duration::from_secs(5)).await;
}
});
}
pub async fn refresh(&self) -> Result<(), sts::Error> {
let session_token = self.client.get_session_token().send().await?;
let sts_credentials = session_token
.credentials
.expect("should include credentials");
*self.credentials.lock().unwrap() = Some(Credentials::new(
sts_credentials.access_key_id.unwrap(),
sts_credentials.secret_access_key.unwrap(),
sts_credentials.session_token,
sts_credentials
.expiration
.map(|expiry| expiry.to_system_time().expect("sts sent a time < 0")),
"Sts",
));
Ok(())
}
}