Fix capture analysis for by-move closure bodies

This commit is contained in:
Michael Goulet 2024-04-01 14:42:00 -04:00
parent 88c2f4f5f5
commit a1a1f41027
5 changed files with 240 additions and 17 deletions

View File

@ -3,6 +3,8 @@
//! be a coroutine body that takes all of its upvars by-move, and which we stash
//! into the `CoroutineInfo` for all coroutines returned by coroutine-closures.
use itertools::Itertools;
use rustc_data_structures::unord::UnordSet;
use rustc_hir as hir;
use rustc_middle::mir::visit::MutVisitor;
@ -26,36 +28,68 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
if coroutine_ty.references_error() {
return;
}
let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!("{body:#?}") };
let coroutine_kind = args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap();
let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!("{body:#?}") };
let args = args.as_coroutine();
let coroutine_kind = args.kind_ty().to_opt_closure_kind().unwrap();
if coroutine_kind == ty::ClosureKind::FnOnce {
return;
}
let mut by_ref_fields = UnordSet::default();
let by_move_upvars = Ty::new_tup_from_iter(
tcx,
tcx.closure_captures(coroutine_def_id).iter().enumerate().map(|(idx, capture)| {
if capture.is_by_ref() {
by_ref_fields.insert(FieldIdx::from_usize(idx));
}
capture.place.ty()
}),
let parent_def_id = tcx.local_parent(coroutine_def_id);
let ty::CoroutineClosure(_, parent_args) =
*tcx.type_of(parent_def_id).instantiate_identity().kind()
else {
bug!();
};
let parent_args = parent_args.as_coroutine_closure();
let parent_upvars_ty = parent_args.tupled_upvars_ty();
let tupled_inputs_ty = tcx.instantiate_bound_regions_with_erased(
parent_args.coroutine_closure_sig().map_bound(|sig| sig.tupled_inputs_ty),
);
let num_args = tupled_inputs_ty.tuple_fields().len();
let mut by_ref_fields = UnordSet::default();
for (idx, (coroutine_capture, parent_capture)) in tcx
.closure_captures(coroutine_def_id)
.iter()
// By construction we capture all the args first.
.skip(num_args)
.zip_eq(tcx.closure_captures(parent_def_id))
.enumerate()
{
// This argument is captured by-move from the parent closure, but by-ref
// from the inner async block. That means that it's being borrowed from
// the closure body -- we need to change the coroutine take it by move.
if coroutine_capture.is_by_ref() && !parent_capture.is_by_ref() {
by_ref_fields.insert(FieldIdx::from_usize(num_args + idx));
}
// Make sure we're actually talking about the same capture.
assert_eq!(coroutine_capture.place.ty(), parent_capture.place.ty());
}
let by_move_coroutine_ty = Ty::new_coroutine(
tcx,
coroutine_def_id.to_def_id(),
ty::CoroutineArgs::new(
tcx,
ty::CoroutineArgsParts {
parent_args: args.as_coroutine().parent_args(),
parent_args: args.parent_args(),
kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce),
resume_ty: args.as_coroutine().resume_ty(),
yield_ty: args.as_coroutine().yield_ty(),
return_ty: args.as_coroutine().return_ty(),
witness: args.as_coroutine().witness(),
tupled_upvars_ty: by_move_upvars,
resume_ty: args.resume_ty(),
yield_ty: args.yield_ty(),
return_ty: args.return_ty(),
witness: args.witness(),
// Concatenate the args + closure's captures (since they're all by move).
tupled_upvars_ty: Ty::new_tup_from_iter(
tcx,
tupled_inputs_ty
.tuple_fields()
.iter()
.chain(parent_upvars_ty.tuple_fields()),
),
},
)
.args,

View File

@ -0,0 +1,89 @@
#![feature(async_closure, noop_waker)]
use std::future::Future;
use std::pin::pin;
use std::task::*;
pub fn block_on<T>(fut: impl Future<Output = T>) -> T {
let mut fut = pin!(fut);
let ctx = &mut Context::from_waker(Waker::noop());
loop {
match fut.as_mut().poll(ctx) {
Poll::Pending => {}
Poll::Ready(t) => break t,
}
}
}
fn main() {
block_on(async_main());
}
async fn call<T>(f: &impl async Fn() -> T) -> T {
f().await
}
async fn call_once<T>(f: impl async FnOnce() -> T) -> T {
f().await
}
#[derive(Debug)]
#[allow(unused)]
struct Hello(i32);
async fn async_main() {
// Capture something by-ref
{
let x = Hello(0);
let c = async || {
println!("{x:?}");
};
call(&c).await;
call_once(c).await;
let x = &Hello(1);
let c = async || {
println!("{x:?}");
};
call(&c).await;
call_once(c).await;
}
// Capture something and consume it (force to `AsyncFnOnce`)
{
let x = Hello(2);
let c = async || {
println!("{x:?}");
drop(x);
};
call_once(c).await;
}
// Capture something with `move`, don't consume it
{
let x = Hello(3);
let c = async move || {
println!("{x:?}");
};
call(&c).await;
call_once(c).await;
let x = &Hello(4);
let c = async move || {
println!("{x:?}");
};
call(&c).await;
call_once(c).await;
}
// Capture something with `move`, also consume it (so `AsyncFnOnce`)
{
let x = Hello(5);
let c = async move || {
println!("{x:?}");
drop(x);
};
call_once(c).await;
}
}

View File

@ -0,0 +1,10 @@
Hello(0)
Hello(0)
Hello(1)
Hello(1)
Hello(2)
Hello(3)
Hello(3)
Hello(4)
Hello(4)
Hello(5)

View File

@ -0,0 +1,80 @@
//@ aux-build:block-on.rs
//@ edition:2021
//@ run-pass
//@ check-run-results
#![feature(async_closure)]
extern crate block_on;
fn main() {
block_on::block_on(async_main());
}
async fn call<T>(f: &impl async Fn() -> T) -> T {
f().await
}
async fn call_once<T>(f: impl async FnOnce() -> T) -> T {
f().await
}
#[derive(Debug)]
#[allow(unused)]
struct Hello(i32);
async fn async_main() {
// Capture something by-ref
{
let x = Hello(0);
let c = async || {
println!("{x:?}");
};
call(&c).await;
call_once(c).await;
let x = &Hello(1);
let c = async || {
println!("{x:?}");
};
call(&c).await;
call_once(c).await;
}
// Capture something and consume it (force to `AsyncFnOnce`)
{
let x = Hello(2);
let c = async || {
println!("{x:?}");
drop(x);
};
call_once(c).await;
}
// Capture something with `move`, don't consume it
{
let x = Hello(3);
let c = async move || {
println!("{x:?}");
};
call(&c).await;
call_once(c).await;
let x = &Hello(4);
let c = async move || {
println!("{x:?}");
};
call(&c).await;
call_once(c).await;
}
// Capture something with `move`, also consume it (so `AsyncFnOnce`)
{
let x = Hello(5);
let c = async move || {
println!("{x:?}");
drop(x);
};
call_once(c).await;
}
}

View File

@ -0,0 +1,10 @@
Hello(0)
Hello(0)
Hello(1)
Hello(1)
Hello(2)
Hello(3)
Hello(3)
Hello(4)
Hello(4)
Hello(5)