mirror of https://github.com/rust-lang/rust.git
Rework coroutine transform to be more flexible in preparation for async generators
This commit is contained in:
parent
ae612bedcb
commit
a0cbc168c9
|
@ -66,9 +66,9 @@ use rustc_index::{Idx, IndexVec};
|
||||||
use rustc_middle::mir::dump_mir;
|
use rustc_middle::mir::dump_mir;
|
||||||
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
|
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
|
||||||
use rustc_middle::mir::*;
|
use rustc_middle::mir::*;
|
||||||
|
use rustc_middle::ty::CoroutineArgs;
|
||||||
use rustc_middle::ty::InstanceDef;
|
use rustc_middle::ty::InstanceDef;
|
||||||
use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt};
|
use rustc_middle::ty::{self, Ty, TyCtxt};
|
||||||
use rustc_middle::ty::{CoroutineArgs, GenericArgsRef};
|
|
||||||
use rustc_mir_dataflow::impls::{
|
use rustc_mir_dataflow::impls::{
|
||||||
MaybeBorrowedLocals, MaybeLiveLocals, MaybeRequiresStorage, MaybeStorageLive,
|
MaybeBorrowedLocals, MaybeLiveLocals, MaybeRequiresStorage, MaybeStorageLive,
|
||||||
};
|
};
|
||||||
|
@ -225,8 +225,6 @@ struct SuspensionPoint<'tcx> {
|
||||||
struct TransformVisitor<'tcx> {
|
struct TransformVisitor<'tcx> {
|
||||||
tcx: TyCtxt<'tcx>,
|
tcx: TyCtxt<'tcx>,
|
||||||
coroutine_kind: hir::CoroutineKind,
|
coroutine_kind: hir::CoroutineKind,
|
||||||
state_adt_ref: AdtDef<'tcx>,
|
|
||||||
state_args: GenericArgsRef<'tcx>,
|
|
||||||
|
|
||||||
// The type of the discriminant in the coroutine struct
|
// The type of the discriminant in the coroutine struct
|
||||||
discr_ty: Ty<'tcx>,
|
discr_ty: Ty<'tcx>,
|
||||||
|
@ -245,21 +243,34 @@ struct TransformVisitor<'tcx> {
|
||||||
always_live_locals: BitSet<Local>,
|
always_live_locals: BitSet<Local>,
|
||||||
|
|
||||||
// The original RETURN_PLACE local
|
// The original RETURN_PLACE local
|
||||||
new_ret_local: Local,
|
old_ret_local: Local,
|
||||||
|
|
||||||
|
old_yield_ty: Ty<'tcx>,
|
||||||
|
|
||||||
|
old_ret_ty: Ty<'tcx>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'tcx> TransformVisitor<'tcx> {
|
impl<'tcx> TransformVisitor<'tcx> {
|
||||||
fn insert_none_ret_block(&self, body: &mut Body<'tcx>) -> BasicBlock {
|
fn insert_none_ret_block(&self, body: &mut Body<'tcx>) -> BasicBlock {
|
||||||
|
assert!(matches!(self.coroutine_kind, CoroutineKind::Gen(_)));
|
||||||
|
|
||||||
let block = BasicBlock::new(body.basic_blocks.len());
|
let block = BasicBlock::new(body.basic_blocks.len());
|
||||||
|
|
||||||
let source_info = SourceInfo::outermost(body.span);
|
let source_info = SourceInfo::outermost(body.span);
|
||||||
|
let option_def_id = self.tcx.require_lang_item(LangItem::Option, None);
|
||||||
|
|
||||||
let (kind, idx) = self.coroutine_state_adt_and_variant_idx(true);
|
|
||||||
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
|
|
||||||
let statements = vec![Statement {
|
let statements = vec![Statement {
|
||||||
kind: StatementKind::Assign(Box::new((
|
kind: StatementKind::Assign(Box::new((
|
||||||
Place::return_place(),
|
Place::return_place(),
|
||||||
Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
|
Rvalue::Aggregate(
|
||||||
|
Box::new(AggregateKind::Adt(
|
||||||
|
option_def_id,
|
||||||
|
VariantIdx::from_usize(0),
|
||||||
|
self.tcx.mk_args(&[self.old_yield_ty.into()]),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)),
|
||||||
|
IndexVec::new(),
|
||||||
|
),
|
||||||
))),
|
))),
|
||||||
source_info,
|
source_info,
|
||||||
}];
|
}];
|
||||||
|
@ -273,23 +284,6 @@ impl<'tcx> TransformVisitor<'tcx> {
|
||||||
block
|
block
|
||||||
}
|
}
|
||||||
|
|
||||||
fn coroutine_state_adt_and_variant_idx(
|
|
||||||
&self,
|
|
||||||
is_return: bool,
|
|
||||||
) -> (AggregateKind<'tcx>, VariantIdx) {
|
|
||||||
let idx = VariantIdx::new(match (is_return, self.coroutine_kind) {
|
|
||||||
(true, hir::CoroutineKind::Coroutine) => 1, // CoroutineState::Complete
|
|
||||||
(false, hir::CoroutineKind::Coroutine) => 0, // CoroutineState::Yielded
|
|
||||||
(true, hir::CoroutineKind::Async(_)) => 0, // Poll::Ready
|
|
||||||
(false, hir::CoroutineKind::Async(_)) => 1, // Poll::Pending
|
|
||||||
(true, hir::CoroutineKind::Gen(_)) => 0, // Option::None
|
|
||||||
(false, hir::CoroutineKind::Gen(_)) => 1, // Option::Some
|
|
||||||
});
|
|
||||||
|
|
||||||
let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_args, None, None);
|
|
||||||
(kind, idx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make a `CoroutineState` or `Poll` variant assignment.
|
// Make a `CoroutineState` or `Poll` variant assignment.
|
||||||
//
|
//
|
||||||
// `core::ops::CoroutineState` only has single element tuple variants,
|
// `core::ops::CoroutineState` only has single element tuple variants,
|
||||||
|
@ -302,51 +296,99 @@ impl<'tcx> TransformVisitor<'tcx> {
|
||||||
is_return: bool,
|
is_return: bool,
|
||||||
statements: &mut Vec<Statement<'tcx>>,
|
statements: &mut Vec<Statement<'tcx>>,
|
||||||
) {
|
) {
|
||||||
let (kind, idx) = self.coroutine_state_adt_and_variant_idx(is_return);
|
let rvalue = match self.coroutine_kind {
|
||||||
|
|
||||||
match self.coroutine_kind {
|
|
||||||
// `Poll::Pending`
|
|
||||||
CoroutineKind::Async(_) => {
|
CoroutineKind::Async(_) => {
|
||||||
if !is_return {
|
let poll_def_id = self.tcx.require_lang_item(LangItem::Poll, None);
|
||||||
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
|
let args = self.tcx.mk_args(&[self.old_ret_ty.into()]);
|
||||||
|
|
||||||
// FIXME(swatinem): assert that `val` is indeed unit?
|
|
||||||
statements.push(Statement {
|
|
||||||
kind: StatementKind::Assign(Box::new((
|
|
||||||
Place::return_place(),
|
|
||||||
Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
|
|
||||||
))),
|
|
||||||
source_info,
|
|
||||||
});
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// `Option::None`
|
|
||||||
CoroutineKind::Gen(_) => {
|
|
||||||
if is_return {
|
if is_return {
|
||||||
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
|
// Poll::Ready(val)
|
||||||
|
Rvalue::Aggregate(
|
||||||
statements.push(Statement {
|
Box::new(AggregateKind::Adt(
|
||||||
kind: StatementKind::Assign(Box::new((
|
poll_def_id,
|
||||||
Place::return_place(),
|
VariantIdx::from_usize(0),
|
||||||
Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
|
args,
|
||||||
))),
|
None,
|
||||||
source_info,
|
None,
|
||||||
});
|
)),
|
||||||
return;
|
IndexVec::from_raw(vec![val]),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
// Poll::Pending
|
||||||
|
Rvalue::Aggregate(
|
||||||
|
Box::new(AggregateKind::Adt(
|
||||||
|
poll_def_id,
|
||||||
|
VariantIdx::from_usize(1),
|
||||||
|
args,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)),
|
||||||
|
IndexVec::new(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
CoroutineKind::Coroutine => {}
|
CoroutineKind::Gen(_) => {
|
||||||
}
|
let option_def_id = self.tcx.require_lang_item(LangItem::Option, None);
|
||||||
|
let args = self.tcx.mk_args(&[self.old_yield_ty.into()]);
|
||||||
// else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)`, `CoroutineState::Complete(x)`, or `Option::Some(x)`
|
if is_return {
|
||||||
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1);
|
// None
|
||||||
|
Rvalue::Aggregate(
|
||||||
|
Box::new(AggregateKind::Adt(
|
||||||
|
option_def_id,
|
||||||
|
VariantIdx::from_usize(0),
|
||||||
|
args,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)),
|
||||||
|
IndexVec::new(),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
// Some(val)
|
||||||
|
Rvalue::Aggregate(
|
||||||
|
Box::new(AggregateKind::Adt(
|
||||||
|
option_def_id,
|
||||||
|
VariantIdx::from_usize(1),
|
||||||
|
args,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)),
|
||||||
|
IndexVec::from_raw(vec![val]),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
CoroutineKind::Coroutine => {
|
||||||
|
let coroutine_state_def_id =
|
||||||
|
self.tcx.require_lang_item(LangItem::CoroutineState, None);
|
||||||
|
let args = self.tcx.mk_args(&[self.old_yield_ty.into(), self.old_ret_ty.into()]);
|
||||||
|
if is_return {
|
||||||
|
// CoroutineState::Complete(val)
|
||||||
|
Rvalue::Aggregate(
|
||||||
|
Box::new(AggregateKind::Adt(
|
||||||
|
coroutine_state_def_id,
|
||||||
|
VariantIdx::from_usize(1),
|
||||||
|
args,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)),
|
||||||
|
IndexVec::from_raw(vec![val]),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
// CoroutineState::Yielded(val)
|
||||||
|
Rvalue::Aggregate(
|
||||||
|
Box::new(AggregateKind::Adt(
|
||||||
|
coroutine_state_def_id,
|
||||||
|
VariantIdx::from_usize(0),
|
||||||
|
args,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)),
|
||||||
|
IndexVec::from_raw(vec![val]),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
statements.push(Statement {
|
statements.push(Statement {
|
||||||
kind: StatementKind::Assign(Box::new((
|
kind: StatementKind::Assign(Box::new((Place::return_place(), rvalue))),
|
||||||
Place::return_place(),
|
|
||||||
Rvalue::Aggregate(Box::new(kind), [val].into()),
|
|
||||||
))),
|
|
||||||
source_info,
|
source_info,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -420,7 +462,7 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
|
||||||
|
|
||||||
let ret_val = match data.terminator().kind {
|
let ret_val = match data.terminator().kind {
|
||||||
TerminatorKind::Return => {
|
TerminatorKind::Return => {
|
||||||
Some((true, None, Operand::Move(Place::from(self.new_ret_local)), None))
|
Some((true, None, Operand::Move(Place::from(self.old_ret_local)), None))
|
||||||
}
|
}
|
||||||
TerminatorKind::Yield { ref value, resume, resume_arg, drop } => {
|
TerminatorKind::Yield { ref value, resume, resume_arg, drop } => {
|
||||||
Some((false, Some((resume, resume_arg)), value.clone(), drop))
|
Some((false, Some((resume, resume_arg)), value.clone(), drop))
|
||||||
|
@ -1493,10 +1535,11 @@ pub(crate) fn mir_coroutine_witnesses<'tcx>(
|
||||||
|
|
||||||
impl<'tcx> MirPass<'tcx> for StateTransform {
|
impl<'tcx> MirPass<'tcx> for StateTransform {
|
||||||
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
|
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
|
||||||
let Some(yield_ty) = body.yield_ty() else {
|
let Some(old_yield_ty) = body.yield_ty() else {
|
||||||
// This only applies to coroutines
|
// This only applies to coroutines
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
|
let old_ret_ty = body.return_ty();
|
||||||
|
|
||||||
assert!(body.coroutine_drop().is_none());
|
assert!(body.coroutine_drop().is_none());
|
||||||
|
|
||||||
|
@ -1520,34 +1563,33 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
|
||||||
|
|
||||||
let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_)));
|
let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_)));
|
||||||
let is_gen_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Gen(_)));
|
let is_gen_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Gen(_)));
|
||||||
let (state_adt_ref, state_args) = match body.coroutine_kind().unwrap() {
|
let new_ret_ty = match body.coroutine_kind().unwrap() {
|
||||||
CoroutineKind::Async(_) => {
|
CoroutineKind::Async(_) => {
|
||||||
// Compute Poll<return_ty>
|
// Compute Poll<return_ty>
|
||||||
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
|
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
|
||||||
let poll_adt_ref = tcx.adt_def(poll_did);
|
let poll_adt_ref = tcx.adt_def(poll_did);
|
||||||
let poll_args = tcx.mk_args(&[body.return_ty().into()]);
|
let poll_args = tcx.mk_args(&[old_ret_ty.into()]);
|
||||||
(poll_adt_ref, poll_args)
|
Ty::new_adt(tcx, poll_adt_ref, poll_args)
|
||||||
}
|
}
|
||||||
CoroutineKind::Gen(_) => {
|
CoroutineKind::Gen(_) => {
|
||||||
// Compute Option<yield_ty>
|
// Compute Option<yield_ty>
|
||||||
let option_did = tcx.require_lang_item(LangItem::Option, None);
|
let option_did = tcx.require_lang_item(LangItem::Option, None);
|
||||||
let option_adt_ref = tcx.adt_def(option_did);
|
let option_adt_ref = tcx.adt_def(option_did);
|
||||||
let option_args = tcx.mk_args(&[body.yield_ty().unwrap().into()]);
|
let option_args = tcx.mk_args(&[old_yield_ty.into()]);
|
||||||
(option_adt_ref, option_args)
|
Ty::new_adt(tcx, option_adt_ref, option_args)
|
||||||
}
|
}
|
||||||
CoroutineKind::Coroutine => {
|
CoroutineKind::Coroutine => {
|
||||||
// Compute CoroutineState<yield_ty, return_ty>
|
// Compute CoroutineState<yield_ty, return_ty>
|
||||||
let state_did = tcx.require_lang_item(LangItem::CoroutineState, None);
|
let state_did = tcx.require_lang_item(LangItem::CoroutineState, None);
|
||||||
let state_adt_ref = tcx.adt_def(state_did);
|
let state_adt_ref = tcx.adt_def(state_did);
|
||||||
let state_args = tcx.mk_args(&[yield_ty.into(), body.return_ty().into()]);
|
let state_args = tcx.mk_args(&[old_yield_ty.into(), old_ret_ty.into()]);
|
||||||
(state_adt_ref, state_args)
|
Ty::new_adt(tcx, state_adt_ref, state_args)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let ret_ty = Ty::new_adt(tcx, state_adt_ref, state_args);
|
|
||||||
|
|
||||||
// We rename RETURN_PLACE which has type mir.return_ty to new_ret_local
|
// We rename RETURN_PLACE which has type mir.return_ty to old_ret_local
|
||||||
// RETURN_PLACE then is a fresh unused local with type ret_ty.
|
// RETURN_PLACE then is a fresh unused local with type ret_ty.
|
||||||
let new_ret_local = replace_local(RETURN_PLACE, ret_ty, body, tcx);
|
let old_ret_local = replace_local(RETURN_PLACE, new_ret_ty, body, tcx);
|
||||||
|
|
||||||
// Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
|
// Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
|
||||||
if is_async_kind {
|
if is_async_kind {
|
||||||
|
@ -1564,9 +1606,10 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
|
||||||
} else {
|
} else {
|
||||||
body.local_decls[resume_local].ty
|
body.local_decls[resume_local].ty
|
||||||
};
|
};
|
||||||
let new_resume_local = replace_local(resume_local, resume_ty, body, tcx);
|
let old_resume_local = replace_local(resume_local, resume_ty, body, tcx);
|
||||||
|
|
||||||
// When first entering the coroutine, move the resume argument into its new local.
|
// When first entering the coroutine, move the resume argument into its old local
|
||||||
|
// (which is now a generator interior).
|
||||||
let source_info = SourceInfo::outermost(body.span);
|
let source_info = SourceInfo::outermost(body.span);
|
||||||
let stmts = &mut body.basic_blocks_mut()[START_BLOCK].statements;
|
let stmts = &mut body.basic_blocks_mut()[START_BLOCK].statements;
|
||||||
stmts.insert(
|
stmts.insert(
|
||||||
|
@ -1574,7 +1617,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
|
||||||
Statement {
|
Statement {
|
||||||
source_info,
|
source_info,
|
||||||
kind: StatementKind::Assign(Box::new((
|
kind: StatementKind::Assign(Box::new((
|
||||||
new_resume_local.into(),
|
old_resume_local.into(),
|
||||||
Rvalue::Use(Operand::Move(resume_local.into())),
|
Rvalue::Use(Operand::Move(resume_local.into())),
|
||||||
))),
|
))),
|
||||||
},
|
},
|
||||||
|
@ -1610,14 +1653,14 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
|
||||||
let mut transform = TransformVisitor {
|
let mut transform = TransformVisitor {
|
||||||
tcx,
|
tcx,
|
||||||
coroutine_kind: body.coroutine_kind().unwrap(),
|
coroutine_kind: body.coroutine_kind().unwrap(),
|
||||||
state_adt_ref,
|
|
||||||
state_args,
|
|
||||||
remap,
|
remap,
|
||||||
storage_liveness,
|
storage_liveness,
|
||||||
always_live_locals,
|
always_live_locals,
|
||||||
suspension_points: Vec::new(),
|
suspension_points: Vec::new(),
|
||||||
new_ret_local,
|
old_ret_local,
|
||||||
discr_ty,
|
discr_ty,
|
||||||
|
old_ret_ty,
|
||||||
|
old_yield_ty,
|
||||||
};
|
};
|
||||||
transform.visit_body(body);
|
transform.visit_body(body);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue