Add lifetimes to async traits that take args by reference (#3061)

This PR adds lifetimes to the `IdentityResolver`, `DnsResolver` (renamed
to `ResolveDns`), and `EndpointResolver` traits so that lifetime
gymnastics aren't needed when implementing those traits. For example,
`IdentityResolver::resolve_identity` takes `&ConfigBag` as an argument,
which means you have to pull things out of the ConfigBag outside of any
returned async block in order for the compiler to be satisfied. This
change removes that consideration and makes implementing these traits a
lot easier.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
This commit is contained in:
John DiSanti 2023-10-13 09:56:50 -07:00 committed by GitHub
parent 8439f2ae43
commit d293d1f762
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 106 additions and 96 deletions

View File

@ -343,3 +343,15 @@ message = """
references = ["smithy-rs#3032"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
author = "ysaito1001"
[[aws-sdk-rust]]
message = "Lifetimes have been added to `EndpointResolver` and `IdentityResolver` traits."
references = ["smithy-rs#3061"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "jdisanti"
[[smithy-rs]]
message = "Lifetimes have been added to the `EndpointResolver` trait."
references = ["smithy-rs#3061"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
author = "jdisanti"

View File

@ -16,7 +16,7 @@ allowed_external_types = [
"aws_smithy_http::endpoint",
"aws_smithy_http::endpoint::error::InvalidEndpointError",
"aws_smithy_http::result::SdkError",
"aws_smithy_runtime_api::client::dns::DnsResolver",
"aws_smithy_runtime_api::client::dns::ResolveDns",
"aws_smithy_runtime_api::client::dns::SharedDnsResolver",
"aws_smithy_runtime_api::client::http::HttpClient",
"aws_smithy_runtime_api::client::http::SharedHttpClient",

View File

@ -50,7 +50,7 @@ use crate::http_credential_provider::HttpCredentialProvider;
use crate::provider_config::ProviderConfig;
use aws_credential_types::provider::{self, error::CredentialsError, future, ProvideCredentials};
use aws_smithy_http::endpoint::apply_endpoint;
use aws_smithy_runtime_api::client::dns::{DnsResolver, ResolveDnsError, SharedDnsResolver};
use aws_smithy_runtime_api::client::dns::{ResolveDns, ResolveDnsError, SharedDnsResolver};
use aws_smithy_runtime_api::client::http::HttpConnectorSettings;
use aws_smithy_runtime_api::shared::IntoShared;
use aws_smithy_types::error::display::DisplayErrorContext;
@ -272,9 +272,9 @@ impl Builder {
/// Override the DNS resolver used to validate URIs
///
/// URIs must refer to loopback addresses. The [`DnsResolver`](aws_smithy_runtime_api::client::dns::DnsResolver)
/// is used to retrieve IP addresses for a given domain.
pub fn dns(mut self, dns: impl DnsResolver + 'static) -> Self {
/// URIs must refer to loopback addresses. The [`ResolveDns`](aws_smithy_runtime_api::client::dns::ResolveDns)
/// implementation is used to retrieve IP addresses for a given domain.
pub fn dns(mut self, dns: impl ResolveDns + 'static) -> Self {
self.dns = Some(dns.into_shared());
self
}
@ -399,7 +399,7 @@ async fn validate_full_uri(
Ok(addr) => addr.is_loopback(),
Err(_domain_name) => {
let dns = dns.ok_or(InvalidFullUriErrorKind::NoDnsResolver)?;
dns.resolve_dns(host.to_owned())
dns.resolve_dns(host)
.await
.map_err(|err| InvalidFullUriErrorKind::DnsLookupFailed(ResolveDnsError::new(err)))?
.iter()
@ -751,16 +751,16 @@ mod test {
}
}
impl DnsResolver for TestDns {
fn resolve_dns(&self, name: String) -> DnsFuture {
DnsFuture::ready(Ok(self.addrs.get(&name).unwrap_or(&self.fallback).clone()))
impl ResolveDns for TestDns {
fn resolve_dns<'a>(&'a self, name: &'a str) -> DnsFuture<'a> {
DnsFuture::ready(Ok(self.addrs.get(name).unwrap_or(&self.fallback).clone()))
}
}
#[derive(Debug)]
struct NeverDns;
impl DnsResolver for NeverDns {
fn resolve_dns(&self, _name: String) -> DnsFuture {
impl ResolveDns for NeverDns {
fn resolve_dns<'a>(&'a self, _name: &'a str) -> DnsFuture<'a> {
DnsFuture::new(async {
Never::new().await;
unreachable!()

View File

@ -524,11 +524,10 @@ struct ImdsEndpointResolver {
}
impl EndpointResolver for ImdsEndpointResolver {
fn resolve_endpoint(&self, _: &EndpointResolverParams) -> EndpointFuture {
let this = self.clone();
fn resolve_endpoint<'a>(&'a self, _: &'a EndpointResolverParams) -> EndpointFuture<'a> {
EndpointFuture::new(async move {
this.endpoint_source
.endpoint(this.mode_override)
self.endpoint_source
.endpoint(self.mode_override.clone())
.await
.map(|uri| Endpoint::builder().url(uri.to_string()).build())
.map_err(|err| err.into())

View File

@ -192,23 +192,19 @@ fn parse_token_response(response: &HttpResponse, now: SystemTime) -> Result<Toke
}
impl IdentityResolver for TokenResolver {
fn resolve_identity(&self, _config_bag: &ConfigBag) -> IdentityFuture {
let this = self.clone();
IdentityFuture::new(async move {
let preloaded_token = this
fn resolve_identity<'a>(&'a self, _config_bag: &'a ConfigBag) -> IdentityFuture<'a> {
IdentityFuture::new(async {
let preloaded_token = self
.inner
.cache
.yield_or_clear_if_expired(this.inner.time_source.now())
.yield_or_clear_if_expired(self.inner.time_source.now())
.await;
let token = match preloaded_token {
Some(token) => Ok(token),
None => {
this.inner
self.inner
.cache
.get_or_load(|| {
let this = this.clone();
async move { this.get_token().await }
})
.get_or_load(|| async { self.get_token().await })
.await
}
}?;

View File

@ -20,7 +20,9 @@ use tokio::sync::oneshot::{Receiver, Sender};
/// Endpoint reloader
#[must_use]
pub struct ReloadEndpoint {
loader: Box<dyn Fn() -> BoxFuture<(Endpoint, SystemTime), ResolveEndpointError> + Send + Sync>,
loader: Box<
dyn Fn() -> BoxFuture<'static, (Endpoint, SystemTime), ResolveEndpointError> + Send + Sync,
>,
endpoint: Arc<Mutex<Option<ExpiringEndpoint>>>,
error: Arc<Mutex<Option<ResolveEndpointError>>>,
rx: Receiver<()>,

View File

@ -23,7 +23,7 @@ pub mod credentials {
}
impl IdentityResolver for CredentialsIdentityResolver {
fn resolve_identity(&self, _config_bag: &ConfigBag) -> IdentityFuture {
fn resolve_identity<'a>(&'a self, _config_bag: &'a ConfigBag) -> IdentityFuture<'a> {
let cache = self.credentials_cache.clone();
IdentityFuture::new(async move {
let credentials = cache.as_ref().provide_cached_credentials().await?;

View File

@ -15,4 +15,4 @@ pub mod rendezvous;
pub mod timeout;
/// A boxed future that outputs a `Result<T, E>`.
pub type BoxFuture<T, E> = Pin<Box<dyn Future<Output = Result<T, E>> + Send>>;
pub type BoxFuture<'a, T, E> = Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'a>>;

View File

@ -3,28 +3,54 @@
* SPDX-License-Identifier: Apache-2.0
*/
/// Declares a new-type for a future that is returned from an async trait (prior to stable async trait).
///
/// To declare a future with a static lifetime:
/// ```ignore
/// new_type_future! {
/// doc = "some rustdoc for the future's struct",
/// pub struct NameOfFuture<'static, OutputType, ErrorType>;
/// }
/// ```
///
/// To declare a future with a non-static lifetime:
/// ```ignore
/// new_type_future! {
/// doc = "some rustdoc for the future's struct",
/// pub struct NameOfFuture<'a, OutputType, ErrorType>;
/// }
/// ```
macro_rules! new_type_future {
(
doc = $type_docs:literal,
pub struct $future_name:ident<$output:ty, $err:ty>,
#[doc = $type_docs:literal]
pub struct $future_name:ident<'static, $output:ty, $err:ty>;
) => {
new_type_future!(@internal, $type_docs, $future_name, $output, $err, 'static,);
};
(
#[doc = $type_docs:literal]
pub struct $future_name:ident<$lifetime:lifetime, $output:ty, $err:ty>;
) => {
new_type_future!(@internal, $type_docs, $future_name, $output, $err, $lifetime, <$lifetime>);
};
(@internal, $type_docs:literal, $future_name:ident, $output:ty, $err:ty, $lifetime:lifetime, $($decl_lifetime:tt)*) => {
pin_project_lite::pin_project! {
#[allow(clippy::type_complexity)]
#[doc = $type_docs]
pub struct $future_name {
pub struct $future_name$($decl_lifetime)* {
#[pin]
inner: aws_smithy_async::future::now_or_later::NowOrLater<
Result<$output, $err>,
aws_smithy_async::future::BoxFuture<$output, $err>
aws_smithy_async::future::BoxFuture<$lifetime, $output, $err>
>,
}
}
impl $future_name {
impl$($decl_lifetime)* $future_name$($decl_lifetime)* {
#[doc = concat!("Create a new `", stringify!($future_name), "` with the given future.")]
pub fn new<F>(future: F) -> Self
where
F: std::future::Future<Output = Result<$output, $err>> + Send + 'static,
F: std::future::Future<Output = Result<$output, $err>> + Send + $lifetime,
{
Self {
inner: aws_smithy_async::future::now_or_later::NowOrLater::new(Box::pin(future)),
@ -38,7 +64,7 @@ macro_rules! new_type_future {
")]
pub fn new_boxed(
future: std::pin::Pin<
Box<dyn std::future::Future<Output = Result<$output, $err>> + Send>,
Box<dyn std::future::Future<Output = Result<$output, $err>> + Send + $lifetime>,
>,
) -> Self {
Self {
@ -54,7 +80,7 @@ macro_rules! new_type_future {
}
}
impl std::future::Future for $future_name {
impl$($decl_lifetime)* std::future::Future for $future_name$($decl_lifetime)* {
type Output = Result<$output, $err>;
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {

View File

@ -7,14 +7,10 @@
use crate::box_error::BoxError;
use crate::impl_shared_conversions;
use aws_smithy_async::future::now_or_later::NowOrLater;
use std::error::Error as StdError;
use std::fmt;
use std::future::Future;
use std::net::IpAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
/// Error that occurs when failing to perform a DNS lookup.
#[derive(Debug)]
@ -43,57 +39,35 @@ impl StdError for ResolveDnsError {
}
}
type BoxFuture<T> = aws_smithy_async::future::BoxFuture<T, ResolveDnsError>;
/// New-type for the future returned by the [`DnsResolver`] trait.
pub struct DnsFuture(NowOrLater<Result<Vec<IpAddr>, ResolveDnsError>, BoxFuture<Vec<IpAddr>>>);
impl DnsFuture {
/// Create a new `DnsFuture`
pub fn new(
future: impl Future<Output = Result<Vec<IpAddr>, ResolveDnsError>> + Send + 'static,
) -> Self {
Self(NowOrLater::new(Box::pin(future)))
}
/// Create a `DnsFuture` that is immediately ready
pub fn ready(result: Result<Vec<IpAddr>, ResolveDnsError>) -> Self {
Self(NowOrLater::ready(result))
}
}
impl Future for DnsFuture {
type Output = Result<Vec<IpAddr>, ResolveDnsError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.as_mut();
let inner = Pin::new(&mut this.0);
Future::poll(inner, cx)
}
new_type_future! {
#[doc = "New-type for the future returned by the [`ResolveDns`] trait."]
pub struct DnsFuture<'a, Vec<IpAddr>, ResolveDnsError>;
}
/// Trait for resolving domain names
pub trait DnsResolver: fmt::Debug + Send + Sync {
pub trait ResolveDns: fmt::Debug + Send + Sync {
/// Asynchronously resolve the given domain name
fn resolve_dns(&self, name: String) -> DnsFuture;
fn resolve_dns<'a>(&'a self, name: &'a str) -> DnsFuture<'a>;
}
/// Shared DNS resolver
/// Shared instance of [`ResolveDns`].
#[derive(Clone, Debug)]
pub struct SharedDnsResolver(Arc<dyn DnsResolver>);
pub struct SharedDnsResolver(Arc<dyn ResolveDns>);
impl SharedDnsResolver {
/// Create a new `SharedDnsResolver`.
pub fn new(resolver: impl DnsResolver + 'static) -> Self {
pub fn new(resolver: impl ResolveDns + 'static) -> Self {
Self(Arc::new(resolver))
}
}
impl DnsResolver for SharedDnsResolver {
fn resolve_dns(&self, name: String) -> DnsFuture {
impl ResolveDns for SharedDnsResolver {
fn resolve_dns<'a>(&'a self, name: &'a str) -> DnsFuture<'a> {
self.0.resolve_dns(name)
}
}
impl_shared_conversions!(convert SharedDnsResolver from DnsResolver using SharedDnsResolver::new);
impl_shared_conversions!(convert SharedDnsResolver from ResolveDns using SharedDnsResolver::new);
#[cfg(test)]
mod tests {
@ -102,6 +76,6 @@ mod tests {
#[test]
fn check_send() {
fn is_send<T: Send>() {}
is_send::<DnsFuture>();
is_send::<DnsFuture<'_>>();
}
}

View File

@ -14,8 +14,8 @@ use std::fmt;
use std::sync::Arc;
new_type_future! {
doc = "Future for [`EndpointResolver::resolve_endpoint`].",
pub struct EndpointFuture<Endpoint, BoxError>,
#[doc = "Future for [`EndpointResolver::resolve_endpoint`]."]
pub struct EndpointFuture<'a, Endpoint, BoxError>;
}
/// Parameters originating from the Smithy endpoint ruleset required for endpoint resolution.
@ -45,7 +45,7 @@ impl Storable for EndpointResolverParams {
/// Configurable endpoint resolver implementation.
pub trait EndpointResolver: Send + Sync + fmt::Debug {
/// Asynchronously resolves an endpoint to use from the given endpoint parameters.
fn resolve_endpoint(&self, params: &EndpointResolverParams) -> EndpointFuture;
fn resolve_endpoint<'a>(&'a self, params: &'a EndpointResolverParams) -> EndpointFuture<'a>;
}
/// Shared endpoint resolver.
@ -62,7 +62,7 @@ impl SharedEndpointResolver {
}
impl EndpointResolver for SharedEndpointResolver {
fn resolve_endpoint(&self, params: &EndpointResolverParams) -> EndpointFuture {
fn resolve_endpoint<'a>(&'a self, params: &'a EndpointResolverParams) -> EndpointFuture<'a> {
self.0.resolve_endpoint(params)
}
}

View File

@ -62,8 +62,8 @@ use std::sync::Arc;
use std::time::Duration;
new_type_future! {
doc = "Future for [`HttpConnector::call`].",
pub struct HttpConnectorFuture<HttpResponse, ConnectorError>,
#[doc = "Future for [`HttpConnector::call`]."]
pub struct HttpConnectorFuture<'static, HttpResponse, ConnectorError>;
}
/// Trait with a `call` function that asynchronously converts a request into a response.

View File

@ -17,8 +17,8 @@ use std::time::SystemTime;
pub mod http;
new_type_future! {
doc = "Future for [`IdentityResolver::resolve_identity`].",
pub struct IdentityFuture<Identity, BoxError>,
#[doc = "Future for [`IdentityResolver::resolve_identity`]."]
pub struct IdentityFuture<'a, Identity, BoxError>;
}
/// Resolver for identities.
@ -34,7 +34,7 @@ new_type_future! {
/// There is no fallback to other auth schemes in the absence of an identity.
pub trait IdentityResolver: Send + Sync + Debug {
/// Asynchronously resolves an identity for a request using the given config.
fn resolve_identity(&self, config_bag: &ConfigBag) -> IdentityFuture;
fn resolve_identity<'a>(&'a self, config_bag: &'a ConfigBag) -> IdentityFuture<'a>;
}
/// Container for a shared identity resolver.
@ -49,7 +49,7 @@ impl SharedIdentityResolver {
}
impl IdentityResolver for SharedIdentityResolver {
fn resolve_identity(&self, config_bag: &ConfigBag) -> IdentityFuture {
fn resolve_identity<'a>(&'a self, config_bag: &'a ConfigBag) -> IdentityFuture<'a> {
self.0.resolve_identity(config_bag)
}
}

View File

@ -64,7 +64,7 @@ impl From<String> for Token {
}
impl IdentityResolver for Token {
fn resolve_identity(&self, _config_bag: &ConfigBag) -> IdentityFuture {
fn resolve_identity<'a>(&'a self, _config_bag: &'a ConfigBag) -> IdentityFuture<'a> {
IdentityFuture::ready(Ok(Identity::new(self.clone(), self.0.expiration)))
}
}
@ -123,7 +123,7 @@ impl Login {
}
impl IdentityResolver for Login {
fn resolve_identity(&self, _config_bag: &ConfigBag) -> IdentityFuture {
fn resolve_identity<'a>(&'a self, _config_bag: &'a ConfigBag) -> IdentityFuture<'a> {
IdentityFuture::ready(Ok(Identity::new(self.clone(), self.0.expiration)))
}
}

View File

@ -579,7 +579,7 @@ impl RuntimeComponentsBuilder {
#[derive(Debug)]
struct FakeEndpointResolver;
impl EndpointResolver for FakeEndpointResolver {
fn resolve_endpoint(&self, _: &EndpointResolverParams) -> EndpointFuture {
fn resolve_endpoint<'a>(&'a self, _: &'a EndpointResolverParams) -> EndpointFuture<'a> {
unreachable!("fake endpoint resolver must be overridden for this test")
}
}
@ -606,7 +606,7 @@ impl RuntimeComponentsBuilder {
#[derive(Debug)]
struct FakeIdentityResolver;
impl IdentityResolver for FakeIdentityResolver {
fn resolve_identity(&self, _: &ConfigBag) -> IdentityFuture {
fn resolve_identity<'a>(&'a self, _: &'a ConfigBag) -> IdentityFuture<'a> {
unreachable!("fake identity resolver must be overridden for this test")
}
}

View File

@ -7,7 +7,7 @@
#[cfg(all(feature = "rt-tokio", not(target_family = "wasm")))]
mod tokio {
use aws_smithy_runtime_api::client::dns::{DnsFuture, DnsResolver, ResolveDnsError};
use aws_smithy_runtime_api::client::dns::{DnsFuture, ResolveDns, ResolveDnsError};
use std::io::{Error as IoError, ErrorKind as IoErrorKind};
use std::net::ToSocketAddrs;
@ -25,8 +25,9 @@ mod tokio {
}
}
impl DnsResolver for TokioDnsResolver {
fn resolve_dns(&self, name: String) -> DnsFuture {
impl ResolveDns for TokioDnsResolver {
fn resolve_dns<'a>(&'a self, name: &'a str) -> DnsFuture<'a> {
let name = name.to_string();
DnsFuture::new(async move {
let result = tokio::task::spawn_blocking(move || (name, 0).to_socket_addrs()).await;
match result {

View File

@ -333,7 +333,7 @@ pub struct LoggingDnsResolver {
impl Service<Name> for LoggingDnsResolver {
type Response = Once<SocketAddr>;
type Error = Infallible;
type Future = BoxFuture<Self::Response, Self::Error>;
type Future = BoxFuture<'static, Self::Response, Self::Error>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))

View File

@ -29,7 +29,7 @@ impl NoAuthIdentityResolver {
}
impl IdentityResolver for NoAuthIdentityResolver {
fn resolve_identity(&self, _: &ConfigBag) -> IdentityFuture {
fn resolve_identity<'a>(&'a self, _: &'a ConfigBag) -> IdentityFuture<'a> {
IdentityFuture::ready(Ok(Identity::new(NoAuthIdentity::new(), None)))
}
}

View File

@ -160,7 +160,7 @@ mod tests {
#[derive(Debug)]
struct TestIdentityResolver;
impl IdentityResolver for TestIdentityResolver {
fn resolve_identity(&self, _config_bag: &ConfigBag) -> IdentityFuture {
fn resolve_identity<'a>(&'a self, _config_bag: &'a ConfigBag) -> IdentityFuture<'a> {
IdentityFuture::ready(Ok(Identity::new("doesntmatter", None)))
}
}

View File

@ -46,7 +46,7 @@ impl StaticUriEndpointResolver {
}
impl EndpointResolver for StaticUriEndpointResolver {
fn resolve_endpoint(&self, _params: &EndpointResolverParams) -> EndpointFuture {
fn resolve_endpoint<'a>(&'a self, _params: &'a EndpointResolverParams) -> EndpointFuture<'a> {
EndpointFuture::ready(Ok(Endpoint::builder()
.url(self.endpoint.to_string())
.build()))
@ -101,7 +101,7 @@ impl<Params> EndpointResolver for DefaultEndpointResolver<Params>
where
Params: Debug + Send + Sync + 'static,
{
fn resolve_endpoint(&self, params: &EndpointResolverParams) -> EndpointFuture {
fn resolve_endpoint<'a>(&'a self, params: &'a EndpointResolverParams) -> EndpointFuture<'a> {
let ep = match params.get::<Params>() {
Some(params) => self.inner.resolve_endpoint(params).map_err(Box::new),
None => Err(Box::new(ResolveEndpointError::message(