mirror of https://github.com/smithy-lang/smithy-rs
Add convenience to async provide credentials from a closure (#577)
This commit is contained in:
parent
081387bda1
commit
ba0d182e87
|
@ -44,17 +44,61 @@ type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
|
|||
|
||||
/// An asynchronous credentials provider
|
||||
///
|
||||
/// If your use-case is synchronous, you should implement [ProvideCredentials] instead.
|
||||
/// If your use-case is synchronous, you should implement [`ProvideCredentials`] instead. Otherwise,
|
||||
/// consider using [`async_provide_credentials_fn`] with a closure rather than directly implementing
|
||||
/// this trait.
|
||||
pub trait AsyncProvideCredentials: Send + Sync {
|
||||
fn provide_credentials(&self) -> BoxFuture<CredentialsResult>;
|
||||
}
|
||||
|
||||
pub type CredentialsProvider = Arc<dyn AsyncProvideCredentials>;
|
||||
|
||||
/// A [`AsyncProvideCredentials`] implemented by a closure.
|
||||
///
|
||||
/// See [`async_provide_credentials_fn`] for more details.
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct AsyncProvideCredentialsFn<T: Send + Sync> {
|
||||
f: T,
|
||||
}
|
||||
|
||||
impl<T, F> AsyncProvideCredentials for AsyncProvideCredentialsFn<T>
|
||||
where
|
||||
T: Fn() -> F + Send + Sync,
|
||||
F: Future<Output = CredentialsResult> + Send + 'static,
|
||||
{
|
||||
fn provide_credentials(&self) -> BoxFuture<CredentialsResult> {
|
||||
Box::pin((self.f)())
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a new [`AsyncProvideCredentialsFn`] with the given closure. This allows you
|
||||
/// to create an [`AsyncProvideCredentials`] implementation from an async block that returns
|
||||
/// a [`CredentialsResult`].
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use aws_auth::Credentials;
|
||||
/// use aws_auth::provider::async_provide_credentials_fn;
|
||||
///
|
||||
/// async_provide_credentials_fn(|| async {
|
||||
/// // Async process to retrieve credentials goes here
|
||||
/// let credentials: Credentials = todo!().await?;
|
||||
/// Ok(credentials)
|
||||
/// });
|
||||
/// ```
|
||||
pub fn async_provide_credentials_fn<T, F>(f: T) -> AsyncProvideCredentialsFn<T>
|
||||
where
|
||||
T: Fn() -> F + Send + Sync,
|
||||
F: Future<Output = CredentialsResult> + Send + 'static,
|
||||
{
|
||||
AsyncProvideCredentialsFn { f }
|
||||
}
|
||||
|
||||
/// A synchronous credentials provider
|
||||
///
|
||||
/// This is offered as a convenience for credential provider implementations that don't
|
||||
/// need to be async. Otherwise, implement [AsyncProvideCredentials].
|
||||
/// need to be async. Otherwise, implement [`AsyncProvideCredentials`].
|
||||
pub trait ProvideCredentials: Send + Sync {
|
||||
fn provide_credentials(&self) -> Result<Credentials, CredentialsError>;
|
||||
}
|
||||
|
|
|
@ -3,9 +3,7 @@
|
|||
* SPDX-License-Identifier: Apache-2.0.
|
||||
*/
|
||||
|
||||
use aws_auth::provider::{CredentialsError, ProvideCredentials};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, SystemTime};
|
||||
use aws_auth::provider::{async_provide_credentials_fn, CredentialsError};
|
||||
use sts::Credentials;
|
||||
|
||||
/// Implements a basic version of ProvideCredentials with AWS STS
|
||||
|
@ -14,80 +12,35 @@ use sts::Credentials;
|
|||
async fn main() -> Result<(), dynamodb::Error> {
|
||||
tracing_subscriber::fmt::init();
|
||||
let client = sts::Client::from_env();
|
||||
let sts_provider = StsCredentialsProvider {
|
||||
client,
|
||||
credentials: Arc::new(Mutex::new(None)),
|
||||
};
|
||||
sts_provider.spawn_refresh_loop().await;
|
||||
|
||||
// NOTE: Do not use this in production! This will grab new credentials for every request.
|
||||
// A high quality caching credential provider implementation is in the roadmap.
|
||||
let dynamodb_conf = dynamodb::Config::builder()
|
||||
.credentials_provider(sts_provider)
|
||||
.credentials_provider(async_provide_credentials_fn(move || {
|
||||
let client = client.clone();
|
||||
async move {
|
||||
let session_token = client
|
||||
.get_session_token()
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| CredentialsError::Unhandled(Box::new(err)))?;
|
||||
let sts_credentials = session_token
|
||||
.credentials
|
||||
.expect("should include credentials");
|
||||
Ok(Credentials::new(
|
||||
sts_credentials.access_key_id.unwrap(),
|
||||
sts_credentials.secret_access_key.unwrap(),
|
||||
sts_credentials.session_token,
|
||||
sts_credentials
|
||||
.expiration
|
||||
.map(|expiry| expiry.to_system_time().expect("sts sent a time < 0")),
|
||||
"Sts",
|
||||
))
|
||||
}
|
||||
}))
|
||||
.build();
|
||||
|
||||
let client = dynamodb::Client::from_conf(dynamodb_conf);
|
||||
println!("tables: {:?}", client.list_tables().send().await?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// This is a rough example of how you could implement ProvideCredentials with Amazon STS.
|
||||
///
|
||||
/// Do not use this in production! A high quality implementation is in the roadmap.
|
||||
#[derive(Clone)]
|
||||
struct StsCredentialsProvider {
|
||||
client: sts::Client,
|
||||
credentials: Arc<Mutex<Option<Credentials>>>,
|
||||
}
|
||||
|
||||
impl ProvideCredentials for StsCredentialsProvider {
|
||||
fn provide_credentials(&self) -> Result<Credentials, CredentialsError> {
|
||||
let inner = self.credentials.lock().unwrap().clone();
|
||||
inner.ok_or(CredentialsError::CredentialsNotLoaded)
|
||||
}
|
||||
}
|
||||
|
||||
impl StsCredentialsProvider {
|
||||
pub async fn spawn_refresh_loop(&self) {
|
||||
let _ = self
|
||||
.refresh()
|
||||
.await
|
||||
.map_err(|e| eprintln!("failed to load credentials! {}", e));
|
||||
let this = self.clone();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let needs_refresh = {
|
||||
let creds = this.credentials.lock().unwrap();
|
||||
let expiry = creds.as_ref().and_then(|creds| creds.expiry());
|
||||
if creds.is_none() {
|
||||
true
|
||||
} else {
|
||||
expiry
|
||||
.map(|expiry| SystemTime::now() > expiry)
|
||||
.unwrap_or(false)
|
||||
}
|
||||
};
|
||||
if needs_refresh {
|
||||
let _ = this
|
||||
.refresh()
|
||||
.await
|
||||
.map_err(|e| eprintln!("failed to load credentials! {}", e));
|
||||
}
|
||||
tokio::time::sleep(Duration::from_secs(5)).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
pub async fn refresh(&self) -> Result<(), sts::Error> {
|
||||
let session_token = self.client.get_session_token().send().await?;
|
||||
let sts_credentials = session_token
|
||||
.credentials
|
||||
.expect("should include credentials");
|
||||
*self.credentials.lock().unwrap() = Some(Credentials::new(
|
||||
sts_credentials.access_key_id.unwrap(),
|
||||
sts_credentials.secret_access_key.unwrap(),
|
||||
sts_credentials.session_token,
|
||||
sts_credentials
|
||||
.expiration
|
||||
.map(|expiry| expiry.to_system_time().expect("sts sent a time < 0")),
|
||||
"Sts",
|
||||
));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue