Split coroutine desugaring kind from source

This commit is contained in:
Michael Goulet 2023-12-21 18:49:20 +00:00
parent d6d7a93866
commit 004450506e
30 changed files with 448 additions and 239 deletions

View File

@ -670,7 +670,10 @@ impl<'hir> LoweringContext<'_, 'hir> {
let params = arena_vec![self; param];
let body = self.lower_body(move |this| {
this.coroutine_kind = Some(hir::CoroutineKind::Async(async_coroutine_source));
this.coroutine_kind = Some(hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::Async,
async_coroutine_source,
));
let old_ctx = this.task_context;
this.task_context = Some(task_context_hid);
@ -724,7 +727,10 @@ impl<'hir> LoweringContext<'_, 'hir> {
});
let body = self.lower_body(move |this| {
this.coroutine_kind = Some(hir::CoroutineKind::Gen(coroutine_source));
this.coroutine_kind = Some(hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::Gen,
coroutine_source,
));
let res = body(this);
(&[], res)
@ -802,7 +808,10 @@ impl<'hir> LoweringContext<'_, 'hir> {
let params = arena_vec![self; param];
let body = self.lower_body(move |this| {
this.coroutine_kind = Some(hir::CoroutineKind::AsyncGen(async_coroutine_source));
this.coroutine_kind = Some(hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::AsyncGen,
async_coroutine_source,
));
let old_ctx = this.task_context;
this.task_context = Some(task_context_hid);
@ -888,9 +897,11 @@ impl<'hir> LoweringContext<'_, 'hir> {
let full_span = expr.span.to(await_kw_span);
let is_async_gen = match self.coroutine_kind {
Some(hir::CoroutineKind::Async(_)) => false,
Some(hir::CoroutineKind::AsyncGen(_)) => true,
Some(hir::CoroutineKind::Coroutine) | Some(hir::CoroutineKind::Gen(_)) | None => {
Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _)) => false,
Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _)) => true,
Some(hir::CoroutineKind::Coroutine)
| Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _))
| None => {
return hir::ExprKind::Err(self.tcx.sess.emit_err(AwaitOnlyInAsyncFnAndBlocks {
await_kw_span,
item_span: self.current_item,
@ -1123,9 +1134,9 @@ impl<'hir> LoweringContext<'_, 'hir> {
Some(movability)
}
Some(
hir::CoroutineKind::Gen(_)
| hir::CoroutineKind::Async(_)
| hir::CoroutineKind::AsyncGen(_),
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _)
| hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _)
| hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _),
) => {
panic!("non-`async`/`gen` closure body turned `async`/`gen` during lowering");
}
@ -1638,9 +1649,9 @@ impl<'hir> LoweringContext<'_, 'hir> {
fn lower_expr_yield(&mut self, span: Span, opt_expr: Option<&Expr>) -> hir::ExprKind<'hir> {
let is_async_gen = match self.coroutine_kind {
Some(hir::CoroutineKind::Gen(_)) => false,
Some(hir::CoroutineKind::AsyncGen(_)) => true,
Some(hir::CoroutineKind::Async(_)) => {
Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _)) => false,
Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _)) => true,
Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _)) => {
return hir::ExprKind::Err(
self.tcx.sess.emit_err(AsyncCoroutinesNotSupported { span }),
);

View File

@ -1,5 +1,4 @@
use either::Either;
use hir::PatField;
use rustc_data_structures::captures::Captures;
use rustc_data_structures::fx::FxIndexSet;
use rustc_errors::{
@ -8,6 +7,7 @@ use rustc_errors::{
use rustc_hir as hir;
use rustc_hir::def::{DefKind, Res};
use rustc_hir::intravisit::{walk_block, walk_expr, Visitor};
use rustc_hir::{CoroutineDesugaring, PatField};
use rustc_hir::{CoroutineKind, CoroutineSource, LangItem};
use rustc_infer::traits::ObligationCause;
use rustc_middle::hir::nested_filter::OnlyBodies;
@ -2516,27 +2516,29 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
};
let kind = match use_span.coroutine_kind() {
Some(coroutine_kind) => match coroutine_kind {
CoroutineKind::Gen(kind) => match kind {
CoroutineKind::Desugared(CoroutineDesugaring::Gen, kind) => match kind {
CoroutineSource::Block => "gen block",
CoroutineSource::Closure => "gen closure",
CoroutineSource::Fn => {
bug!("gen block/closure expected, but gen function found.")
}
},
CoroutineKind::AsyncGen(kind) => match kind {
CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, kind) => match kind {
CoroutineSource::Block => "async gen block",
CoroutineSource::Closure => "async gen closure",
CoroutineSource::Fn => {
bug!("gen block/closure expected, but gen function found.")
}
},
CoroutineKind::Async(async_kind) => match async_kind {
CoroutineSource::Block => "async block",
CoroutineSource::Closure => "async closure",
CoroutineSource::Fn => {
bug!("async block/closure expected, but async function found.")
CoroutineKind::Desugared(CoroutineDesugaring::Async, async_kind) => {
match async_kind {
CoroutineSource::Block => "async block",
CoroutineSource::Closure => "async closure",
CoroutineSource::Fn => {
bug!("async block/closure expected, but async function found.")
}
}
},
}
CoroutineKind::Coroutine => "coroutine",
},
None => "closure",
@ -2566,7 +2568,10 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
}
ConstraintCategory::CallArgument(_) => {
fr_name.highlight_region_name(&mut err);
if matches!(use_span.coroutine_kind(), Some(CoroutineKind::Async(_))) {
if matches!(
use_span.coroutine_kind(),
Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, _))
) {
err.note(
"async blocks are not executed immediately and must either take a \
reference or ownership of outside variables they use",

View File

@ -1049,7 +1049,10 @@ impl<'a, 'tcx> MirBorrowckCtxt<'a, 'tcx> {
..
}) => {
let body = map.body(*body);
if !matches!(body.coroutine_kind, Some(hir::CoroutineKind::Async(..))) {
if !matches!(
body.coroutine_kind,
Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _))
) {
closure_span = Some(expr.span.shrink_to_lo());
}
}

View File

@ -684,39 +684,46 @@ impl<'tcx> MirBorrowckCtxt<'_, 'tcx> {
hir::FnRetTy::Return(hir_ty) => (fn_decl.output.span(), Some(hir_ty)),
};
let mir_description = match hir.body(body).coroutine_kind {
Some(hir::CoroutineKind::Async(src)) => match src {
hir::CoroutineSource::Block => " of async block",
hir::CoroutineSource::Closure => " of async closure",
hir::CoroutineSource::Fn => {
let parent_item =
tcx.hir_node_by_def_id(hir.get_parent_item(mir_hir_id).def_id);
let output = &parent_item
.fn_decl()
.expect("coroutine lowered from async fn should be in fn")
.output;
span = output.span();
if let hir::FnRetTy::Return(ret) = output {
hir_ty = Some(self.get_future_inner_return_ty(*ret));
Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, src)) => {
match src {
hir::CoroutineSource::Block => " of async block",
hir::CoroutineSource::Closure => " of async closure",
hir::CoroutineSource::Fn => {
let parent_item =
tcx.hir_node_by_def_id(hir.get_parent_item(mir_hir_id).def_id);
let output = &parent_item
.fn_decl()
.expect("coroutine lowered from async fn should be in fn")
.output;
span = output.span();
if let hir::FnRetTy::Return(ret) = output {
hir_ty = Some(self.get_future_inner_return_ty(*ret));
}
" of async function"
}
" of async function"
}
},
Some(hir::CoroutineKind::Gen(src)) => match src {
hir::CoroutineSource::Block => " of gen block",
hir::CoroutineSource::Closure => " of gen closure",
hir::CoroutineSource::Fn => {
let parent_item =
tcx.hir_node_by_def_id(hir.get_parent_item(mir_hir_id).def_id);
let output = &parent_item
.fn_decl()
.expect("coroutine lowered from gen fn should be in fn")
.output;
span = output.span();
" of gen function"
}
Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, src)) => {
match src {
hir::CoroutineSource::Block => " of gen block",
hir::CoroutineSource::Closure => " of gen closure",
hir::CoroutineSource::Fn => {
let parent_item =
tcx.hir_node_by_def_id(hir.get_parent_item(mir_hir_id).def_id);
let output = &parent_item
.fn_decl()
.expect("coroutine lowered from gen fn should be in fn")
.output;
span = output.span();
" of gen function"
}
}
},
}
Some(hir::CoroutineKind::AsyncGen(src)) => match src {
Some(hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::AsyncGen,
src,
)) => match src {
hir::CoroutineSource::Block => " of async gen block",
hir::CoroutineSource::Closure => " of async gen closure",
hir::CoroutineSource::Fn => {

View File

@ -15,7 +15,7 @@ use rustc_data_structures::fx::FxHashSet;
use rustc_data_structures::stable_hasher::{Hash64, HashStable, StableHasher};
use rustc_hir::def_id::DefId;
use rustc_hir::definitions::{DefPathData, DefPathDataName, DisambiguatedDefPathData};
use rustc_hir::{CoroutineKind, CoroutineSource, Mutability};
use rustc_hir::{CoroutineDesugaring, CoroutineKind, CoroutineSource, Mutability};
use rustc_middle::ty::layout::{IntegerExt, TyAndLayout};
use rustc_middle::ty::{self, ExistentialProjection, ParamEnv, Ty, TyCtxt};
use rustc_middle::ty::{GenericArgKind, GenericArgsRef};
@ -560,15 +560,31 @@ pub fn push_item_name(tcx: TyCtxt<'_>, def_id: DefId, qualified: bool, output: &
fn coroutine_kind_label(coroutine_kind: Option<CoroutineKind>) -> &'static str {
match coroutine_kind {
Some(CoroutineKind::Gen(CoroutineSource::Block)) => "gen_block",
Some(CoroutineKind::Gen(CoroutineSource::Closure)) => "gen_closure",
Some(CoroutineKind::Gen(CoroutineSource::Fn)) => "gen_fn",
Some(CoroutineKind::Async(CoroutineSource::Block)) => "async_block",
Some(CoroutineKind::Async(CoroutineSource::Closure)) => "async_closure",
Some(CoroutineKind::Async(CoroutineSource::Fn)) => "async_fn",
Some(CoroutineKind::AsyncGen(CoroutineSource::Block)) => "async_gen_block",
Some(CoroutineKind::AsyncGen(CoroutineSource::Closure)) => "async_gen_closure",
Some(CoroutineKind::AsyncGen(CoroutineSource::Fn)) => "async_gen_fn",
Some(CoroutineKind::Desugared(CoroutineDesugaring::Gen, CoroutineSource::Block)) => {
"gen_block"
}
Some(CoroutineKind::Desugared(CoroutineDesugaring::Gen, CoroutineSource::Closure)) => {
"gen_closure"
}
Some(CoroutineKind::Desugared(CoroutineDesugaring::Gen, CoroutineSource::Fn)) => "gen_fn",
Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, CoroutineSource::Block)) => {
"async_block"
}
Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, CoroutineSource::Closure)) => {
"async_closure"
}
Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, CoroutineSource::Fn)) => {
"async_fn"
}
Some(CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, CoroutineSource::Block)) => {
"async_gen_block"
}
Some(CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, CoroutineSource::Closure)) => {
"async_gen_closure"
}
Some(CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, CoroutineSource::Fn)) => {
"async_gen_fn"
}
Some(CoroutineKind::Coroutine) => "coroutine",
None => "closure",
}

View File

@ -464,8 +464,12 @@ impl<'tcx> Visitor<'tcx> for Checker<'_, 'tcx> {
Rvalue::Aggregate(kind, ..) => {
if let AggregateKind::Coroutine(def_id, ..) = kind.as_ref()
&& let Some(coroutine_kind @ hir::CoroutineKind::Async(..)) =
self.tcx.coroutine_kind(def_id)
&& let Some(
coroutine_kind @ hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::Async,
_,
),
) = self.tcx.coroutine_kind(def_id)
{
self.check_op(ops::Coroutine(coroutine_kind));
}

View File

@ -359,7 +359,11 @@ impl<'tcx> NonConstOp<'tcx> for FnCallUnstable {
pub struct Coroutine(pub hir::CoroutineKind);
impl<'tcx> NonConstOp<'tcx> for Coroutine {
fn status_in_item(&self, _: &ConstCx<'_, 'tcx>) -> Status {
if let hir::CoroutineKind::Async(hir::CoroutineSource::Block) = self.0 {
if let hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::Async,
hir::CoroutineSource::Block,
) = self.0
{
Status::Unstable(sym::const_async_blocks)
} else {
Status::Forbidden
@ -372,7 +376,11 @@ impl<'tcx> NonConstOp<'tcx> for Coroutine {
span: Span,
) -> DiagnosticBuilder<'tcx, ErrorGuaranteed> {
let msg = format!("{:#}s are not allowed in {}s", self.0, ccx.const_kind());
if let hir::CoroutineKind::Async(hir::CoroutineSource::Block) = self.0 {
if let hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::Async,
hir::CoroutineSource::Block,
) = self.0
{
ccx.tcx.sess.create_feature_err(
errors::UnallowedOpInConstContext { span, msg },
sym::const_async_blocks,

View File

@ -1351,15 +1351,8 @@ impl<'hir> Body<'hir> {
/// The type of source expression that caused this coroutine to be created.
#[derive(Clone, PartialEq, Eq, Debug, Copy, Hash, HashStable_Generic, Encodable, Decodable)]
pub enum CoroutineKind {
/// An explicit `async` block or the body of an `async` function.
Async(CoroutineSource),
/// An explicit `gen` block or the body of a `gen` function.
Gen(CoroutineSource),
/// An explicit `async gen` block or the body of an `async gen` function,
/// which is able to both `yield` and `.await`.
AsyncGen(CoroutineSource),
/// A coroutine that comes from a desugaring.
Desugared(CoroutineDesugaring, CoroutineSource),
/// A coroutine literal created via a `yield` inside a closure.
Coroutine,
@ -1368,31 +1361,11 @@ pub enum CoroutineKind {
impl fmt::Display for CoroutineKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
CoroutineKind::Async(k) => {
if f.alternate() {
f.write_str("`async` ")?;
} else {
f.write_str("async ")?
}
CoroutineKind::Desugared(d, k) => {
d.fmt(f)?;
k.fmt(f)
}
CoroutineKind::Coroutine => f.write_str("coroutine"),
CoroutineKind::Gen(k) => {
if f.alternate() {
f.write_str("`gen` ")?;
} else {
f.write_str("gen ")?
}
k.fmt(f)
}
CoroutineKind::AsyncGen(k) => {
if f.alternate() {
f.write_str("`async gen` ")?;
} else {
f.write_str("async gen ")?
}
k.fmt(f)
}
}
}
}
@ -1425,6 +1398,49 @@ impl fmt::Display for CoroutineSource {
}
}
#[derive(Clone, PartialEq, Eq, Debug, Copy, Hash, HashStable_Generic, Encodable, Decodable)]
pub enum CoroutineDesugaring {
/// An explicit `async` block or the body of an `async` function.
Async,
/// An explicit `gen` block or the body of a `gen` function.
Gen,
/// An explicit `async gen` block or the body of an `async gen` function,
/// which is able to both `yield` and `.await`.
AsyncGen,
}
impl fmt::Display for CoroutineDesugaring {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
CoroutineDesugaring::Async => {
if f.alternate() {
f.write_str("`async` ")?;
} else {
f.write_str("async ")?
}
}
CoroutineDesugaring::Gen => {
if f.alternate() {
f.write_str("`gen` ")?;
} else {
f.write_str("gen ")?
}
}
CoroutineDesugaring::AsyncGen => {
if f.alternate() {
f.write_str("`async gen` ")?;
} else {
f.write_str("async gen ")?
}
}
}
Ok(())
}
}
#[derive(Copy, Clone, Debug)]
pub enum BodyOwnerKind {
/// Functions and methods.

View File

@ -305,8 +305,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
) = (parent_node, callee_node)
{
let fn_decl_span = if hir.body(body).coroutine_kind
== Some(hir::CoroutineKind::Async(hir::CoroutineSource::Closure))
{
== Some(hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::Async,
hir::CoroutineSource::Closure,
)) {
// Actually need to unwrap one more layer of HIR to get to
// the _real_ closure...
let async_closure = hir.parent_id(parent_hir_id);

View File

@ -59,7 +59,8 @@ pub(super) fn check_fn<'a, 'tcx>(
&& can_be_coroutine.is_some()
{
let yield_ty = match kind {
hir::CoroutineKind::Gen(..) | hir::CoroutineKind::Coroutine => {
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _)
| hir::CoroutineKind::Coroutine => {
let yield_ty = fcx.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::TypeInference,
span,
@ -71,7 +72,7 @@ pub(super) fn check_fn<'a, 'tcx>(
// guide inference on the yield type so that we can handle `AsyncIterator`
// in this block in projection correctly. In the new trait solver, it is
// not a problem.
hir::CoroutineKind::AsyncGen(..) => {
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _) => {
let yield_ty = fcx.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::TypeInference,
span,
@ -89,7 +90,7 @@ pub(super) fn check_fn<'a, 'tcx>(
.into()]),
)
}
hir::CoroutineKind::Async(..) => Ty::new_unit(tcx),
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _) => Ty::new_unit(tcx),
};
// Resume type defaults to `()` if the coroutine has no argument.

View File

@ -634,7 +634,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
// In the case of the async block that we create for a function body,
// we expect the return type of the block to match that of the enclosing
// function.
Some(hir::CoroutineKind::Async(hir::CoroutineSource::Fn)) => {
Some(hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::Async,
hir::CoroutineSource::Fn,
)) => {
debug!("closure is async fn body");
let def_id = self.tcx.hir().body_owner_def_id(body.id());
self.deduce_future_output_from_obligations(expr_def_id, def_id).unwrap_or_else(
@ -651,9 +654,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
)
}
// All `gen {}` and `async gen {}` must return unit.
Some(hir::CoroutineKind::Gen(_) | hir::CoroutineKind::AsyncGen(_)) => {
self.tcx.types.unit
}
Some(
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _)
| hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _),
) => self.tcx.types.unit,
_ => astconv.ty_infer(None, decl.output.span()),
},

View File

@ -17,8 +17,8 @@ use rustc_hir::def::Res;
use rustc_hir::def::{CtorKind, CtorOf, DefKind};
use rustc_hir::lang_items::LangItem;
use rustc_hir::{
CoroutineKind, CoroutineSource, Expr, ExprKind, GenericBound, HirId, Node, Path, QPath, Stmt,
StmtKind, TyKind, WherePredicate,
CoroutineDesugaring, CoroutineKind, CoroutineSource, Expr, ExprKind, GenericBound, HirId, Node,
Path, QPath, Stmt, StmtKind, TyKind, WherePredicate,
};
use rustc_hir_analysis::astconv::AstConv;
use rustc_infer::traits::{self, StatementAsExpression};
@ -549,7 +549,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
ty::Coroutine(def_id, ..)
if matches!(
self.tcx.coroutine_kind(def_id),
Some(CoroutineKind::Async(CoroutineSource::Closure))
Some(CoroutineKind::Desugared(
CoroutineDesugaring::Async,
CoroutineSource::Closure
))
) =>
{
errors::SuggestBoxing::AsyncBody

View File

@ -206,16 +206,16 @@ fixed_size_enum! {
fixed_size_enum! {
hir::CoroutineKind {
( Coroutine )
( Gen(hir::CoroutineSource::Block) )
( Gen(hir::CoroutineSource::Fn) )
( Gen(hir::CoroutineSource::Closure) )
( Async(hir::CoroutineSource::Block) )
( Async(hir::CoroutineSource::Fn) )
( Async(hir::CoroutineSource::Closure) )
( AsyncGen(hir::CoroutineSource::Block) )
( AsyncGen(hir::CoroutineSource::Fn) )
( AsyncGen(hir::CoroutineSource::Closure) )
( Coroutine )
( Desugared(hir::CoroutineDesugaring::Gen, hir::CoroutineSource::Block) )
( Desugared(hir::CoroutineDesugaring::Gen, hir::CoroutineSource::Fn) )
( Desugared(hir::CoroutineDesugaring::Gen, hir::CoroutineSource::Closure) )
( Desugared(hir::CoroutineDesugaring::Async, hir::CoroutineSource::Block) )
( Desugared(hir::CoroutineDesugaring::Async, hir::CoroutineSource::Fn) )
( Desugared(hir::CoroutineDesugaring::Async, hir::CoroutineSource::Closure) )
( Desugared(hir::CoroutineDesugaring::AsyncGen, hir::CoroutineSource::Block) )
( Desugared(hir::CoroutineDesugaring::AsyncGen, hir::CoroutineSource::Fn) )
( Desugared(hir::CoroutineDesugaring::AsyncGen, hir::CoroutineSource::Closure) )
}
}

View File

@ -17,7 +17,7 @@ use rustc_data_structures::captures::Captures;
use rustc_errors::{DiagnosticArgValue, DiagnosticMessage, ErrorGuaranteed, IntoDiagnosticArg};
use rustc_hir::def::{CtorKind, Namespace};
use rustc_hir::def_id::{DefId, CRATE_DEF_ID};
use rustc_hir::{self, CoroutineKind, ImplicitSelfKind};
use rustc_hir::{self, CoroutineDesugaring, CoroutineKind, ImplicitSelfKind};
use rustc_hir::{self as hir, HirId};
use rustc_session::Session;
use rustc_target::abi::{FieldIdx, VariantIdx};

View File

@ -148,19 +148,23 @@ impl<O> AssertKind<O> {
DivisionByZero(_) => "attempt to divide by zero",
RemainderByZero(_) => "attempt to calculate the remainder with a divisor of zero",
ResumedAfterReturn(CoroutineKind::Coroutine) => "coroutine resumed after completion",
ResumedAfterReturn(CoroutineKind::Async(_)) => "`async fn` resumed after completion",
ResumedAfterReturn(CoroutineKind::AsyncGen(_)) => {
ResumedAfterReturn(CoroutineKind::Desugared(CoroutineDesugaring::Async, _)) => {
"`async fn` resumed after completion"
}
ResumedAfterReturn(CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)) => {
"`async gen fn` resumed after completion"
}
ResumedAfterReturn(CoroutineKind::Gen(_)) => {
ResumedAfterReturn(CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) => {
"`gen fn` should just keep returning `None` after completion"
}
ResumedAfterPanic(CoroutineKind::Coroutine) => "coroutine resumed after panicking",
ResumedAfterPanic(CoroutineKind::Async(_)) => "`async fn` resumed after panicking",
ResumedAfterPanic(CoroutineKind::AsyncGen(_)) => {
ResumedAfterPanic(CoroutineKind::Desugared(CoroutineDesugaring::Async, _)) => {
"`async fn` resumed after panicking"
}
ResumedAfterPanic(CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)) => {
"`async gen fn` resumed after panicking"
}
ResumedAfterPanic(CoroutineKind::Gen(_)) => {
ResumedAfterPanic(CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) => {
"`gen fn` should just keep returning `None` after panicking"
}
@ -249,17 +253,27 @@ impl<O> AssertKind<O> {
OverflowNeg(_) => middle_assert_overflow_neg,
DivisionByZero(_) => middle_assert_divide_by_zero,
RemainderByZero(_) => middle_assert_remainder_by_zero,
ResumedAfterReturn(CoroutineKind::Async(_)) => middle_assert_async_resume_after_return,
ResumedAfterReturn(CoroutineKind::AsyncGen(_)) => todo!(),
ResumedAfterReturn(CoroutineKind::Gen(_)) => {
ResumedAfterReturn(CoroutineKind::Desugared(CoroutineDesugaring::Async, _)) => {
middle_assert_async_resume_after_return
}
ResumedAfterReturn(CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)) => {
todo!()
}
ResumedAfterReturn(CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) => {
bug!("gen blocks can be resumed after they return and will keep returning `None`")
}
ResumedAfterReturn(CoroutineKind::Coroutine) => {
middle_assert_coroutine_resume_after_return
}
ResumedAfterPanic(CoroutineKind::Async(_)) => middle_assert_async_resume_after_panic,
ResumedAfterPanic(CoroutineKind::AsyncGen(_)) => todo!(),
ResumedAfterPanic(CoroutineKind::Gen(_)) => middle_assert_gen_resume_after_panic,
ResumedAfterPanic(CoroutineKind::Desugared(CoroutineDesugaring::Async, _)) => {
middle_assert_async_resume_after_panic
}
ResumedAfterPanic(CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)) => {
todo!()
}
ResumedAfterPanic(CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) => {
middle_assert_gen_resume_after_panic
}
ResumedAfterPanic(CoroutineKind::Coroutine) => {
middle_assert_coroutine_resume_after_panic
}

View File

@ -849,7 +849,10 @@ impl<'tcx> TyCtxt<'tcx> {
/// Returns `true` if the node pointed to by `def_id` is a coroutine for an async construct.
pub fn coroutine_is_async(self, def_id: DefId) -> bool {
matches!(self.coroutine_kind(def_id), Some(hir::CoroutineKind::Async(_)))
matches!(
self.coroutine_kind(def_id),
Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _))
)
}
/// Returns `true` if the node pointed to by `def_id` is a general coroutine that implements `Coroutine`.
@ -860,12 +863,18 @@ impl<'tcx> TyCtxt<'tcx> {
/// Returns `true` if the node pointed to by `def_id` is a coroutine for a `gen` construct.
pub fn coroutine_is_gen(self, def_id: DefId) -> bool {
matches!(self.coroutine_kind(def_id), Some(hir::CoroutineKind::Gen(_)))
matches!(
self.coroutine_kind(def_id),
Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _))
)
}
/// Returns `true` if the node pointed to by `def_id` is a coroutine for a `async gen` construct.
pub fn coroutine_is_async_gen(self, def_id: DefId) -> bool {
matches!(self.coroutine_kind(def_id), Some(hir::CoroutineKind::AsyncGen(_)))
matches!(
self.coroutine_kind(def_id),
Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _))
)
}
pub fn stability(self) -> &'tcx stability::Index {

View File

@ -728,10 +728,16 @@ impl<'tcx> TyCtxt<'tcx> {
DefKind::AssocFn if self.associated_item(def_id).fn_has_self_parameter => "method",
DefKind::Closure if let Some(coroutine_kind) = self.coroutine_kind(def_id) => {
match coroutine_kind {
rustc_hir::CoroutineKind::Async(..) => "async closure",
rustc_hir::CoroutineKind::AsyncGen(..) => "async gen closure",
rustc_hir::CoroutineKind::Coroutine => "coroutine",
rustc_hir::CoroutineKind::Gen(..) => "gen closure",
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _) => {
"async closure"
}
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _) => {
"async gen closure"
}
hir::CoroutineKind::Coroutine => "coroutine",
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _) => {
"gen closure"
}
}
}
_ => def_kind.descr(def_id),
@ -749,10 +755,10 @@ impl<'tcx> TyCtxt<'tcx> {
DefKind::AssocFn if self.associated_item(def_id).fn_has_self_parameter => "a",
DefKind::Closure if let Some(coroutine_kind) = self.coroutine_kind(def_id) => {
match coroutine_kind {
rustc_hir::CoroutineKind::Async(..) => "an",
rustc_hir::CoroutineKind::AsyncGen(..) => "an",
rustc_hir::CoroutineKind::Coroutine => "a",
rustc_hir::CoroutineKind::Gen(..) => "a",
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, ..) => "an",
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, ..) => "an",
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, ..) => "a",
hir::CoroutineKind::Coroutine => "a",
}
}
_ => def_kind.article(),

View File

@ -59,7 +59,7 @@ use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use rustc_errors::pluralize;
use rustc_hir as hir;
use rustc_hir::lang_items::LangItem;
use rustc_hir::CoroutineKind;
use rustc_hir::{CoroutineDesugaring, CoroutineKind};
use rustc_index::bit_set::{BitMatrix, BitSet, GrowableBitSet};
use rustc_index::{Idx, IndexVec};
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
@ -254,10 +254,12 @@ impl<'tcx> TransformVisitor<'tcx> {
let source_info = SourceInfo::outermost(body.span);
let none_value = match self.coroutine_kind {
CoroutineKind::Async(_) => span_bug!(body.span, "`Future`s are not fused inherently"),
CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
span_bug!(body.span, "`Future`s are not fused inherently")
}
CoroutineKind::Coroutine => span_bug!(body.span, "`Coroutine`s cannot be fused"),
// `gen` continues return `None`
CoroutineKind::Gen(_) => {
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
let option_def_id = self.tcx.require_lang_item(LangItem::Option, None);
Rvalue::Aggregate(
Box::new(AggregateKind::Adt(
@ -271,7 +273,7 @@ impl<'tcx> TransformVisitor<'tcx> {
)
}
// `async gen` continues to return `Poll::Ready(None)`
CoroutineKind::AsyncGen(_) => {
CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
let ty::Adt(_poll_adt, args) = *self.old_yield_ty.kind() else { bug!() };
let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() };
let yield_ty = args.type_at(0);
@ -316,7 +318,7 @@ impl<'tcx> TransformVisitor<'tcx> {
statements: &mut Vec<Statement<'tcx>>,
) {
let rvalue = match self.coroutine_kind {
CoroutineKind::Async(_) => {
CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
let poll_def_id = self.tcx.require_lang_item(LangItem::Poll, None);
let args = self.tcx.mk_args(&[self.old_ret_ty.into()]);
if is_return {
@ -345,7 +347,7 @@ impl<'tcx> TransformVisitor<'tcx> {
)
}
}
CoroutineKind::Gen(_) => {
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
let option_def_id = self.tcx.require_lang_item(LangItem::Option, None);
let args = self.tcx.mk_args(&[self.old_yield_ty.into()]);
if is_return {
@ -374,7 +376,7 @@ impl<'tcx> TransformVisitor<'tcx> {
)
}
}
CoroutineKind::AsyncGen(_) => {
CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
if is_return {
let ty::Adt(_poll_adt, args) = *self.old_yield_ty.kind() else { bug!() };
let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() };
@ -1426,10 +1428,11 @@ fn create_coroutine_resume_function<'tcx>(
if can_return {
let block = match coroutine_kind {
CoroutineKind::Async(_) | CoroutineKind::Coroutine => {
CoroutineKind::Desugared(CoroutineDesugaring::Async, _) | CoroutineKind::Coroutine => {
insert_panic_block(tcx, body, ResumedAfterReturn(coroutine_kind))
}
CoroutineKind::AsyncGen(_) | CoroutineKind::Gen(_) => {
CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)
| CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
transform.insert_none_ret_block(body)
}
};
@ -1443,7 +1446,7 @@ fn create_coroutine_resume_function<'tcx>(
match coroutine_kind {
// Iterator::next doesn't accept a pinned argument,
// unlike for all other coroutine kinds.
CoroutineKind::Gen(_) => {}
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {}
_ => {
make_coroutine_state_argument_pinned(tcx, body);
}
@ -1609,25 +1612,34 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
}
};
let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_)));
let is_async_gen_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::AsyncGen(_)));
let is_gen_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Gen(_)));
let is_async_kind = matches!(
body.coroutine_kind(),
Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, _))
);
let is_async_gen_kind = matches!(
body.coroutine_kind(),
Some(CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _))
);
let is_gen_kind = matches!(
body.coroutine_kind(),
Some(CoroutineKind::Desugared(CoroutineDesugaring::Gen, _))
);
let new_ret_ty = match body.coroutine_kind().unwrap() {
CoroutineKind::Async(_) => {
CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
// Compute Poll<return_ty>
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
let poll_adt_ref = tcx.adt_def(poll_did);
let poll_args = tcx.mk_args(&[old_ret_ty.into()]);
Ty::new_adt(tcx, poll_adt_ref, poll_args)
}
CoroutineKind::Gen(_) => {
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
// Compute Option<yield_ty>
let option_did = tcx.require_lang_item(LangItem::Option, None);
let option_adt_ref = tcx.adt_def(option_did);
let option_args = tcx.mk_args(&[old_yield_ty.into()]);
Ty::new_adt(tcx, option_adt_ref, option_args)
}
CoroutineKind::AsyncGen(_) => {
CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
// The yield ty is already `Poll<Option<yield_ty>>`
old_yield_ty
}

View File

@ -41,17 +41,26 @@ impl<'tcx> Stable<'tcx> for rustc_hir::CoroutineSource {
impl<'tcx> Stable<'tcx> for rustc_hir::CoroutineKind {
type T = stable_mir::mir::CoroutineKind;
fn stable(&self, tables: &mut Tables<'tcx>) -> Self::T {
use rustc_hir::CoroutineKind;
use rustc_hir::{CoroutineDesugaring, CoroutineKind};
match self {
CoroutineKind::Async(source) => {
stable_mir::mir::CoroutineKind::Async(source.stable(tables))
CoroutineKind::Desugared(CoroutineDesugaring::Async, source) => {
stable_mir::mir::CoroutineKind::Desugared(
stable_mir::mir::CoroutineDesugaring::Async,
source.stable(tables),
)
}
CoroutineKind::Gen(source) => {
stable_mir::mir::CoroutineKind::Gen(source.stable(tables))
CoroutineKind::Desugared(CoroutineDesugaring::Gen, source) => {
stable_mir::mir::CoroutineKind::Desugared(
stable_mir::mir::CoroutineDesugaring::Gen,
source.stable(tables),
)
}
CoroutineKind::Coroutine => stable_mir::mir::CoroutineKind::Coroutine,
CoroutineKind::AsyncGen(source) => {
stable_mir::mir::CoroutineKind::AsyncGen(source.stable(tables))
CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, source) => {
stable_mir::mir::CoroutineKind::Desugared(
stable_mir::mir::CoroutineDesugaring::AsyncGen,
source.stable(tables),
)
}
}
}

View File

@ -22,7 +22,7 @@ use rustc_hir::def_id::DefId;
use rustc_hir::intravisit::Visitor;
use rustc_hir::is_range_literal;
use rustc_hir::lang_items::LangItem;
use rustc_hir::{CoroutineKind, CoroutineSource, Node};
use rustc_hir::{CoroutineDesugaring, CoroutineKind, CoroutineSource, Node};
use rustc_hir::{Expr, HirId};
use rustc_infer::infer::error_reporting::TypeErrCtxt;
use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
@ -2578,7 +2578,10 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
.and_then(|coroutine_did| {
Some(match self.tcx.coroutine_kind(coroutine_did).unwrap() {
CoroutineKind::Coroutine => format!("coroutine is not {trait_name}"),
CoroutineKind::Async(CoroutineSource::Fn) => self
CoroutineKind::Desugared(
CoroutineDesugaring::Async,
CoroutineSource::Fn,
) => self
.tcx
.parent(coroutine_did)
.as_local()
@ -2587,13 +2590,22 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
.map(|name| {
format!("future returned by `{name}` is not {trait_name}")
})?,
CoroutineKind::Async(CoroutineSource::Block) => {
CoroutineKind::Desugared(
CoroutineDesugaring::Async,
CoroutineSource::Block,
) => {
format!("future created by async block is not {trait_name}")
}
CoroutineKind::Async(CoroutineSource::Closure) => {
CoroutineKind::Desugared(
CoroutineDesugaring::Async,
CoroutineSource::Closure,
) => {
format!("future created by async closure is not {trait_name}")
}
CoroutineKind::AsyncGen(CoroutineSource::Fn) => self
CoroutineKind::Desugared(
CoroutineDesugaring::AsyncGen,
CoroutineSource::Fn,
) => self
.tcx
.parent(coroutine_did)
.as_local()
@ -2602,27 +2614,40 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
.map(|name| {
format!("async iterator returned by `{name}` is not {trait_name}")
})?,
CoroutineKind::AsyncGen(CoroutineSource::Block) => {
CoroutineKind::Desugared(
CoroutineDesugaring::AsyncGen,
CoroutineSource::Block,
) => {
format!("async iterator created by async gen block is not {trait_name}")
}
CoroutineKind::AsyncGen(CoroutineSource::Closure) => {
CoroutineKind::Desugared(
CoroutineDesugaring::AsyncGen,
CoroutineSource::Closure,
) => {
format!(
"async iterator created by async gen closure is not {trait_name}"
)
}
CoroutineKind::Gen(CoroutineSource::Fn) => self
.tcx
.parent(coroutine_did)
.as_local()
.map(|parent_did| self.tcx.local_def_id_to_hir_id(parent_did))
.and_then(|parent_hir_id| hir.opt_name(parent_hir_id))
.map(|name| {
format!("iterator returned by `{name}` is not {trait_name}")
})?,
CoroutineKind::Gen(CoroutineSource::Block) => {
CoroutineKind::Desugared(CoroutineDesugaring::Gen, CoroutineSource::Fn) => {
self.tcx
.parent(coroutine_did)
.as_local()
.map(|parent_did| self.tcx.local_def_id_to_hir_id(parent_did))
.and_then(|parent_hir_id| hir.opt_name(parent_hir_id))
.map(|name| {
format!("iterator returned by `{name}` is not {trait_name}")
})?
}
CoroutineKind::Desugared(
CoroutineDesugaring::Gen,
CoroutineSource::Block,
) => {
format!("iterator created by gen block is not {trait_name}")
}
CoroutineKind::Gen(CoroutineSource::Closure) => {
CoroutineKind::Desugared(
CoroutineDesugaring::Gen,
CoroutineSource::Closure,
) => {
format!("iterator created by gen closure is not {trait_name}")
}
})
@ -3145,9 +3170,15 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
let what = match self.tcx.coroutine_kind(coroutine_def_id) {
None
| Some(hir::CoroutineKind::Coroutine)
| Some(hir::CoroutineKind::Gen(_)) => "yield",
Some(hir::CoroutineKind::Async(..)) => "await",
Some(hir::CoroutineKind::AsyncGen(_)) => "yield`/`await",
| Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _)) => {
"yield"
}
Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _)) => {
"await"
}
Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _)) => {
"yield`/`await"
}
};
err.note(format!(
"all values live across `{what}` must have a statically known size"
@ -3535,7 +3566,9 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
) {
if let Some(body_id) = self.tcx.hir().maybe_body_owned_by(obligation.cause.body_id) {
let body = self.tcx.hir().body(body_id);
if let Some(hir::CoroutineKind::Async(_)) = body.coroutine_kind {
if let Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _)) =
body.coroutine_kind
{
let future_trait = self.tcx.require_lang_item(LangItem::Future, None);
let self_ty = self.resolve_vars_if_possible(trait_pred.self_ty());

View File

@ -1,3 +1,5 @@
// ignore-tidy-filelength :(
use super::on_unimplemented::{AppendConstMessage, OnUnimplementedNote, TypeErrCtxtExt as _};
use super::suggestions::{get_explanation_based_on_obligation, TypeErrCtxtExt as _};
use crate::errors::{ClosureFnMutLabel, ClosureFnOnceLabel, ClosureKindMismatch};
@ -1926,15 +1928,42 @@ impl<'tcx> InferCtxtPrivExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
fn describe_coroutine(&self, body_id: hir::BodyId) -> Option<&'static str> {
self.tcx.hir().body(body_id).coroutine_kind.map(|coroutine_source| match coroutine_source {
hir::CoroutineKind::Coroutine => "a coroutine",
hir::CoroutineKind::Async(hir::CoroutineSource::Block) => "an async block",
hir::CoroutineKind::Async(hir::CoroutineSource::Fn) => "an async function",
hir::CoroutineKind::Async(hir::CoroutineSource::Closure) => "an async closure",
hir::CoroutineKind::AsyncGen(hir::CoroutineSource::Block) => "an async gen block",
hir::CoroutineKind::AsyncGen(hir::CoroutineSource::Fn) => "an async gen function",
hir::CoroutineKind::AsyncGen(hir::CoroutineSource::Closure) => "an async gen closure",
hir::CoroutineKind::Gen(hir::CoroutineSource::Block) => "a gen block",
hir::CoroutineKind::Gen(hir::CoroutineSource::Fn) => "a gen function",
hir::CoroutineKind::Gen(hir::CoroutineSource::Closure) => "a gen closure",
hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::Async,
hir::CoroutineSource::Block,
) => "an async block",
hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::Async,
hir::CoroutineSource::Fn,
) => "an async function",
hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::Async,
hir::CoroutineSource::Closure,
) => "an async closure",
hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::AsyncGen,
hir::CoroutineSource::Block,
) => "an async gen block",
hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::AsyncGen,
hir::CoroutineSource::Fn,
) => "an async gen function",
hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::AsyncGen,
hir::CoroutineSource::Closure,
) => "an async gen closure",
hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::Gen,
hir::CoroutineSource::Block,
) => "a gen block",
hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::Gen,
hir::CoroutineSource::Fn,
) => "a gen function",
hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::Gen,
hir::CoroutineSource::Closure,
) => "a gen closure",
})
}

View File

@ -114,13 +114,13 @@ fn fn_sig_for_fn_abi<'tcx>(
let pin_adt_ref = tcx.adt_def(pin_did);
let pin_args = tcx.mk_args(&[env_ty.into()]);
let env_ty = match coroutine_kind {
hir::CoroutineKind::Gen(_) => {
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _) => {
// Iterator::next doesn't accept a pinned argument,
// unlike for all other coroutine kinds.
env_ty
}
hir::CoroutineKind::Async(_)
| hir::CoroutineKind::AsyncGen(_)
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _)
| hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _)
| hir::CoroutineKind::Coroutine => Ty::new_adt(tcx, pin_adt_ref, pin_args),
};
@ -131,7 +131,7 @@ fn fn_sig_for_fn_abi<'tcx>(
// or the `Iterator::next(...) -> Option` function in case this is a
// special coroutine backing a gen construct.
let (resume_ty, ret_ty) = match coroutine_kind {
hir::CoroutineKind::Async(_) => {
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _) => {
// The signature should be `Future::poll(_, &mut Context<'_>) -> Poll<Output>`
assert_eq!(sig.yield_ty, tcx.types.unit);
@ -156,7 +156,7 @@ fn fn_sig_for_fn_abi<'tcx>(
(Some(context_mut_ref), ret_ty)
}
hir::CoroutineKind::Gen(_) => {
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _) => {
// The signature should be `Iterator::next(_) -> Option<Yield>`
let option_did = tcx.require_lang_item(LangItem::Option, None);
let option_adt_ref = tcx.adt_def(option_did);
@ -168,7 +168,7 @@ fn fn_sig_for_fn_abi<'tcx>(
(None, ret_ty)
}
hir::CoroutineKind::AsyncGen(_) => {
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _) => {
// The signature should be
// `AsyncIterator::poll_next(_, &mut Context<'_>) -> Poll<Option<Output>>`
assert_eq!(sig.return_ty, tcx.types.unit);

View File

@ -288,27 +288,33 @@ impl AssertMessage {
AssertMessage::ResumedAfterReturn(CoroutineKind::Coroutine) => {
Ok("coroutine resumed after completion")
}
AssertMessage::ResumedAfterReturn(CoroutineKind::Async(_)) => {
Ok("`async fn` resumed after completion")
}
AssertMessage::ResumedAfterReturn(CoroutineKind::Gen(_)) => {
Ok("`async gen fn` resumed after completion")
}
AssertMessage::ResumedAfterReturn(CoroutineKind::AsyncGen(_)) => {
Ok("`gen fn` should just keep returning `AssertMessage::None` after completion")
}
AssertMessage::ResumedAfterReturn(CoroutineKind::Desugared(
CoroutineDesugaring::Async,
_,
)) => Ok("`async fn` resumed after completion"),
AssertMessage::ResumedAfterReturn(CoroutineKind::Desugared(
CoroutineDesugaring::Gen,
_,
)) => Ok("`async gen fn` resumed after completion"),
AssertMessage::ResumedAfterReturn(CoroutineKind::Desugared(
CoroutineDesugaring::AsyncGen,
_,
)) => Ok("`gen fn` should just keep returning `AssertMessage::None` after completion"),
AssertMessage::ResumedAfterPanic(CoroutineKind::Coroutine) => {
Ok("coroutine resumed after panicking")
}
AssertMessage::ResumedAfterPanic(CoroutineKind::Async(_)) => {
Ok("`async fn` resumed after panicking")
}
AssertMessage::ResumedAfterPanic(CoroutineKind::Gen(_)) => {
Ok("`async gen fn` resumed after panicking")
}
AssertMessage::ResumedAfterPanic(CoroutineKind::AsyncGen(_)) => {
Ok("`gen fn` should just keep returning `AssertMessage::None` after panicking")
}
AssertMessage::ResumedAfterPanic(CoroutineKind::Desugared(
CoroutineDesugaring::Async,
_,
)) => Ok("`async fn` resumed after panicking"),
AssertMessage::ResumedAfterPanic(CoroutineKind::Desugared(
CoroutineDesugaring::Gen,
_,
)) => Ok("`async gen fn` resumed after panicking"),
AssertMessage::ResumedAfterPanic(CoroutineKind::Desugared(
CoroutineDesugaring::AsyncGen,
_,
)) => Ok("`gen fn` should just keep returning `AssertMessage::None` after panicking"),
AssertMessage::BoundsCheck { .. } => Ok("index out of bounds"),
AssertMessage::MisalignedPointerDereference { .. } => {
@ -392,10 +398,8 @@ pub enum UnOp {
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CoroutineKind {
Async(CoroutineSource),
Desugared(CoroutineDesugaring, CoroutineSource),
Coroutine,
Gen(CoroutineSource),
AsyncGen(CoroutineSource),
}
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
@ -405,6 +409,15 @@ pub enum CoroutineSource {
Fn,
}
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum CoroutineDesugaring {
Async,
Gen,
AsyncGen,
}
pub(crate) type LocalDefId = Opaque;
/// The rustc coverage data structures are heavily tied to internal details of the
/// coverage implementation that are likely to change, and are unlikely to be

View File

@ -2,7 +2,7 @@ use clippy_utils::diagnostics::span_lint_hir_and_then;
use clippy_utils::source::snippet;
use clippy_utils::ty::implements_trait;
use rustc_errors::Applicability;
use rustc_hir::{Body, BodyId, CoroutineKind, CoroutineSource, ExprKind, QPath};
use rustc_hir::{Body, BodyId, CoroutineKind, CoroutineSource, CoroutineDesugaring, ExprKind, QPath};
use rustc_lint::{LateContext, LateLintPass};
use rustc_session::declare_lint_pass;
@ -45,10 +45,9 @@ declare_lint_pass!(AsyncYieldsAsync => [ASYNC_YIELDS_ASYNC]);
impl<'tcx> LateLintPass<'tcx> for AsyncYieldsAsync {
fn check_body(&mut self, cx: &LateContext<'tcx>, body: &'tcx Body<'_>) {
use CoroutineSource::{Block, Closure};
// For functions, with explicitly defined types, don't warn.
// XXXkhuey maybe we should?
if let Some(CoroutineKind::Async(Block | Closure)) = body.coroutine_kind {
if let Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, CoroutineSource::Block | CoroutineSource::Closure)) = body.coroutine_kind {
if let Some(future_trait_def_id) = cx.tcx.lang_items().future_trait() {
let body_id = BodyId {
hir_id: body.value.hir_id,

View File

@ -3,7 +3,7 @@ use clippy_utils::diagnostics::span_lint_and_then;
use clippy_utils::{match_def_path, paths};
use rustc_data_structures::fx::FxHashMap;
use rustc_hir::def_id::DefId;
use rustc_hir::{Body, CoroutineKind, CoroutineSource};
use rustc_hir::{Body, CoroutineKind, CoroutineDesugaring};
use rustc_lint::{LateContext, LateLintPass};
use rustc_middle::mir::CoroutineLayout;
use rustc_session::impl_lint_pass;
@ -194,8 +194,7 @@ impl LateLintPass<'_> for AwaitHolding {
}
fn check_body(&mut self, cx: &LateContext<'_>, body: &'_ Body<'_>) {
use CoroutineSource::{Block, Closure, Fn};
if let Some(CoroutineKind::Async(Block | Closure | Fn)) = body.coroutine_kind {
if let Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, _)) = body.coroutine_kind {
let def_id = cx.tcx.hir().body_owner_def_id(body.id());
if let Some(coroutine_layout) = cx.tcx.mir_coroutine_witnesses(def_id) {
self.check_interior_types(cx, coroutine_layout);

View File

@ -3,7 +3,7 @@ use clippy_utils::source::{position_before_rarrow, snippet_block, snippet_opt};
use rustc_errors::Applicability;
use rustc_hir::intravisit::FnKind;
use rustc_hir::{
Block, Body, Closure, CoroutineKind, CoroutineSource, Expr, ExprKind, FnDecl, FnRetTy, GenericArg, GenericBound,
Block, Body, Closure, CoroutineKind, CoroutineSource, CoroutineDesugaring, Expr, ExprKind, FnDecl, FnRetTy, GenericArg, GenericBound,
ImplItem, Item, ItemKind, LifetimeName, Node, Term, TraitRef, Ty, TyKind, TypeBindingKind,
};
use rustc_lint::{LateContext, LateLintPass};
@ -178,7 +178,7 @@ fn desugared_async_block<'tcx>(cx: &LateContext<'tcx>, block: &'tcx Block<'tcx>)
..
} = block_expr
&& let closure_body = cx.tcx.hir().body(body)
&& closure_body.coroutine_kind == Some(CoroutineKind::Async(CoroutineSource::Block))
&& closure_body.coroutine_kind == Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, CoroutineSource::Block))
{
return Some(closure_body);
}

View File

@ -3,7 +3,7 @@ use clippy_utils::path_res;
use clippy_utils::source::snippet;
use rustc_errors::Applicability;
use rustc_hir::def::{DefKind, Res};
use rustc_hir::{Block, Body, CoroutineKind, CoroutineSource, Expr, ExprKind, LangItem, MatchSource, QPath};
use rustc_hir::{Block, Body, CoroutineKind, CoroutineSource, CoroutineDesugaring, Expr, ExprKind, LangItem, MatchSource, QPath};
use rustc_lint::{LateContext, LateLintPass};
use rustc_session::declare_lint_pass;
@ -86,7 +86,7 @@ impl LateLintPass<'_> for NeedlessQuestionMark {
}
fn check_body(&mut self, cx: &LateContext<'_>, body: &'_ Body<'_>) {
if let Some(CoroutineKind::Async(CoroutineSource::Fn)) = body.coroutine_kind {
if let Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, CoroutineSource::Fn)) = body.coroutine_kind {
if let ExprKind::Block(
Block {
expr:

View File

@ -5,7 +5,7 @@ use clippy_utils::peel_blocks;
use clippy_utils::source::{snippet, walk_span_to_context};
use clippy_utils::visitors::for_each_expr;
use rustc_errors::Applicability;
use rustc_hir::{Closure, CoroutineKind, CoroutineSource, Expr, ExprKind, MatchSource};
use rustc_hir::{Closure, CoroutineKind, CoroutineSource, CoroutineDesugaring, Expr, ExprKind, MatchSource};
use rustc_lint::{LateContext, LateLintPass};
use rustc_middle::lint::in_external_macro;
use rustc_middle::ty::UpvarCapture;
@ -71,7 +71,7 @@ impl<'tcx> LateLintPass<'tcx> for RedundantAsyncBlock {
fn desugar_async_block<'tcx>(cx: &LateContext<'tcx>, expr: &'tcx Expr<'_>) -> Option<&'tcx Expr<'tcx>> {
if let ExprKind::Closure(Closure { body, def_id, .. }) = expr.kind
&& let body = cx.tcx.hir().body(*body)
&& matches!(body.coroutine_kind, Some(CoroutineKind::Async(CoroutineSource::Block)))
&& matches!(body.coroutine_kind, Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, CoroutineSource::Block)))
{
cx.typeck_results()
.closure_min_captures

View File

@ -5,7 +5,7 @@ use clippy_utils::sugg::Sugg;
use rustc_errors::Applicability;
use rustc_hir as hir;
use rustc_hir::intravisit::{Visitor as HirVisitor, Visitor};
use rustc_hir::{intravisit as hir_visit, CoroutineKind, CoroutineSource, Node};
use rustc_hir::{intravisit as hir_visit, CoroutineKind, CoroutineSource, CoroutineDesugaring, Node};
use rustc_lint::{LateContext, LateLintPass};
use rustc_middle::hir::nested_filter;
use rustc_middle::lint::in_external_macro;
@ -67,7 +67,7 @@ fn is_async_closure(cx: &LateContext<'_>, body: &hir::Body<'_>) -> bool {
if let hir::ExprKind::Closure(innermost_closure_generated_by_desugar) = body.value.kind
&& let desugared_inner_closure_body = cx.tcx.hir().body(innermost_closure_generated_by_desugar.body)
// checks whether it is `async || whatever_expression`
&& let Some(CoroutineKind::Async(CoroutineSource::Closure)) = desugared_inner_closure_body.coroutine_kind
&& let Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, CoroutineSource::Closure)) = desugared_inner_closure_body.coroutine_kind
{
true
} else {

View File

@ -86,7 +86,13 @@ impl<'a, 'tcx> Visitor<'tcx> for AsyncFnVisitor<'a, 'tcx> {
}
fn visit_body(&mut self, b: &'tcx Body<'tcx>) {
let is_async_block = matches!(b.coroutine_kind, Some(rustc_hir::CoroutineKind::Async(_)));
let is_async_block = matches!(
b.coroutine_kind,
Some(rustc_hir::CoroutineKind::Desugared(
rustc_hir::CoroutineDesugaring::Async,
_
))
);
if is_async_block {
self.async_depth += 1;