diff --git a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs index de43f9faff909..320d8fd3977ae 100644 --- a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs +++ b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs @@ -58,16 +58,24 @@ //! borrowing from the outer closure, and we simply peel off a `deref` projection //! from them. This second body is stored alongside the first body, and optimized //! with it in lockstep. When we need to resolve a body for `FnOnce` or `AsyncFnOnce`, -//! we use this "by move" body instead. - -use itertools::Itertools; +//! we use this "by-move" body instead. +//! +//! ## How does this work? +//! +//! This pass essentially remaps the body of the (child) closure of the coroutine-closure +//! to take the set of upvars of the parent closure by value. This at least requires +//! changing a by-ref upvar to be by-value in the case that the outer coroutine-closure +//! captures something by value; however, it may also require renumbering field indices +//! in case precise captures (edition 2021 closure capture rules) caused the inner coroutine +//! to split one field capture into two. -use rustc_data_structures::unord::UnordSet; +use rustc_data_structures::unord::UnordMap; use rustc_hir as hir; +use rustc_middle::hir::place::{PlaceBase, Projection, ProjectionKind}; use rustc_middle::mir::visit::MutVisitor; use rustc_middle::mir::{self, dump_mir, MirPass}; use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt, TypeVisitableExt}; -use rustc_target::abi::FieldIdx; +use rustc_target::abi::{FieldIdx, VariantIdx}; pub struct ByMoveBody; @@ -116,32 +124,116 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody { .tuple_fields() .len(); - let mut by_ref_fields = UnordSet::default(); - for (idx, (coroutine_capture, parent_capture)) in tcx + let mut field_remapping = UnordMap::default(); + + // One parent capture may correspond to several child captures if we end up + // refining the set of captures via edition-2021 precise captures. We want to + // match up any number of child captures with one parent capture, so we keep + // peeking off this `Peekable` until the child doesn't match anymore. + let mut parent_captures = + tcx.closure_captures(parent_def_id).iter().copied().enumerate().peekable(); + // Make sure we use every field at least once, b/c why are we capturing something + // if it's not used in the inner coroutine. + let mut field_used_at_least_once = false; + + for (child_field_idx, child_capture) in tcx .closure_captures(coroutine_def_id) .iter() + .copied() // By construction we capture all the args first. .skip(num_args) - .zip_eq(tcx.closure_captures(parent_def_id)) .enumerate() { - // This upvar is captured by-move from the parent closure, but by-ref - // from the inner async block. That means that it's being borrowed from - // the outer closure body -- we need to change the coroutine to take the - // upvar by value. - if coroutine_capture.is_by_ref() && !parent_capture.is_by_ref() { - assert_ne!( - coroutine_kind, - ty::ClosureKind::FnOnce, - "`FnOnce` coroutine-closures return coroutines that capture from \ - their body; it will always result in a borrowck error!" + loop { + let Some(&(parent_field_idx, parent_capture)) = parent_captures.peek() else { + bug!("we ran out of parent captures!") + }; + + let PlaceBase::Upvar(parent_base) = parent_capture.place.base else { + bug!("expected capture to be an upvar"); + }; + let PlaceBase::Upvar(child_base) = child_capture.place.base else { + bug!("expected capture to be an upvar"); + }; + + assert!( + child_capture.place.projections.len() >= parent_capture.place.projections.len() ); - by_ref_fields.insert(FieldIdx::from_usize(num_args + idx)); + // A parent matches a child they share the same prefix of projections. + // The child may have more, if it is capturing sub-fields out of + // something that is captured by-move in the parent closure. + if parent_base.var_path.hir_id != child_base.var_path.hir_id + || !std::iter::zip( + &child_capture.place.projections, + &parent_capture.place.projections, + ) + .all(|(child, parent)| child.kind == parent.kind) + { + // Make sure the field was used at least once. + assert!( + field_used_at_least_once, + "we captured {parent_capture:#?} but it was not used in the child coroutine?" + ); + field_used_at_least_once = false; + // Skip this field. + let _ = parent_captures.next().unwrap(); + continue; + } + + // Store this set of additional projections (fields and derefs). + // We need to re-apply them later. + let child_precise_captures = + &child_capture.place.projections[parent_capture.place.projections.len()..]; + + // If the parent captures by-move, and the child captures by-ref, then we + // need to peel an additional `deref` off of the body of the child. + let needs_deref = child_capture.is_by_ref() && !parent_capture.is_by_ref(); + if needs_deref { + assert_ne!( + coroutine_kind, + ty::ClosureKind::FnOnce, + "`FnOnce` coroutine-closures return coroutines that capture from \ + their body; it will always result in a borrowck error!" + ); + } + + // Finally, store the type of the parent's captured place. We need + // this when building the field projection in the MIR body later on. + let mut parent_capture_ty = parent_capture.place.ty(); + parent_capture_ty = match parent_capture.info.capture_kind { + ty::UpvarCapture::ByValue => parent_capture_ty, + ty::UpvarCapture::ByRef(kind) => Ty::new_ref( + tcx, + tcx.lifetimes.re_erased, + parent_capture_ty, + kind.to_mutbl_lossy(), + ), + }; + + field_remapping.insert( + FieldIdx::from_usize(child_field_idx + num_args), + ( + FieldIdx::from_usize(parent_field_idx + num_args), + parent_capture_ty, + needs_deref, + child_precise_captures, + ), + ); + + field_used_at_least_once = true; + break; } + } + + // Pop the last parent capture + if field_used_at_least_once { + let _ = parent_captures.next().unwrap(); + } + assert_eq!(parent_captures.next(), None, "leftover parent captures?"); - // Make sure we're actually talking about the same capture. - // FIXME(async_closures): We could look at the `hir::Upvar` instead? - assert_eq!(coroutine_capture.place.ty(), parent_capture.place.ty()); + if coroutine_kind == ty::ClosureKind::FnOnce { + assert_eq!(field_remapping.len(), tcx.closure_captures(parent_def_id).len()); + return; } let by_move_coroutine_ty = tcx @@ -157,7 +249,7 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody { ); let mut by_move_body = body.clone(); - MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body); + MakeByMoveBody { tcx, field_remapping, by_move_coroutine_ty }.visit_body(&mut by_move_body); dump_mir(tcx, false, "coroutine_by_move", &0, &by_move_body, |_, _| Ok(())); by_move_body.source = mir::MirSource::from_instance(InstanceDef::CoroutineKindShim { coroutine_def_id: coroutine_def_id.to_def_id(), @@ -168,7 +260,7 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody { struct MakeByMoveBody<'tcx> { tcx: TyCtxt<'tcx>, - by_ref_fields: UnordSet, + field_remapping: UnordMap, bool, &'tcx [Projection<'tcx>])>, by_move_coroutine_ty: Ty<'tcx>, } @@ -183,24 +275,59 @@ impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> { context: mir::visit::PlaceContext, location: mir::Location, ) { + // Initializing an upvar local always starts with `CAPTURE_STRUCT_LOCAL` and a + // field projection. If this is in `field_remapping`, then it must not be an + // arg from calling the closure, but instead an upvar. if place.local == ty::CAPTURE_STRUCT_LOCAL - && let Some((&mir::ProjectionElem::Field(idx, ty), projection)) = + && let Some((&mir::ProjectionElem::Field(idx, _), projection)) = place.projection.split_first() - && self.by_ref_fields.contains(&idx) + && let Some(&(remapped_idx, remapped_ty, needs_deref, additional_projections)) = + self.field_remapping.get(&idx) { - let (begin, end) = projection.split_first().unwrap(); - // FIXME(async_closures): I'm actually a bit surprised to see that we always - // initially deref the by-ref upvars. If this is not actually true, then we - // will at least get an ICE that explains why this isn't true :^) - assert_eq!(*begin, mir::ProjectionElem::Deref); - // Peel one ref off of the ty. - let peeled_ty = ty.builtin_deref(true).unwrap().ty; + // As noted before, if the parent closure captures a field by value, and + // the child captures a field by ref, then for the by-move body we're + // generating, we also are taking that field by value. Peel off a deref, + // since a layer of reffing has now become redundant. + let final_deref = if needs_deref { + let Some((mir::ProjectionElem::Deref, projection)) = projection.split_first() + else { + bug!( + "There should be at least a single deref for an upvar local initialization, found {projection:#?}" + ); + }; + // There may be more derefs, since we may also implicitly reborrow + // a captured mut pointer. + projection + } else { + projection + }; + + // The only thing that should be left is a deref, if the parent captured + // an upvar by-ref. + std::assert_matches::assert_matches!(final_deref, [] | [mir::ProjectionElem::Deref]); + + // For all of the additional projections that come out of precise capturing, + // re-apply these projections. + let additional_projections = + additional_projections.iter().map(|elem| match elem.kind { + ProjectionKind::Deref => mir::ProjectionElem::Deref, + ProjectionKind::Field(idx, VariantIdx::ZERO) => { + mir::ProjectionElem::Field(idx, elem.ty) + } + _ => unreachable!("precise captures only through fields and derefs"), + }); + + // We start out with an adjusted field index (and ty), representing the + // upvar that we get from our parent closure. We apply any of the additional + // projections to make sure that to the rest of the body of the closure, the + // place looks the same, and then apply that final deref if necessary. *place = mir::Place { local: place.local, projection: self.tcx.mk_place_elems_from_iter( - [mir::ProjectionElem::Field(idx, peeled_ty)] + [mir::ProjectionElem::Field(remapped_idx, remapped_ty)] .into_iter() - .chain(end.iter().copied()), + .chain(additional_projections) + .chain(final_deref.iter().copied()), ), }; } diff --git a/tests/ui/async-await/async-closures/mut-ref-reborrow.rs b/tests/ui/async-await/async-closures/mut-ref-reborrow.rs new file mode 100644 index 0000000000000..9f2cbd7ce1c3d --- /dev/null +++ b/tests/ui/async-await/async-closures/mut-ref-reborrow.rs @@ -0,0 +1,27 @@ +//@ aux-build:block-on.rs +//@ run-pass +//@ check-run-results +//@ revisions: e2021 e2018 +//@[e2018] edition:2018 +//@[e2021] edition:2021 + +#![feature(async_closure)] + +extern crate block_on; + +async fn call_once(f: impl async FnOnce()) { f().await; } + +pub async fn async_closure(x: &mut i32) { + let c = async move || { + *x += 1; + }; + call_once(c).await; +} + +fn main() { + block_on::block_on(async { + let mut x = 0; + async_closure(&mut x).await; + assert_eq!(x, 1); + }); +} diff --git a/tests/ui/async-await/async-closures/overlapping-projs.rs b/tests/ui/async-await/async-closures/overlapping-projs.rs new file mode 100644 index 0000000000000..6dd00b16103f7 --- /dev/null +++ b/tests/ui/async-await/async-closures/overlapping-projs.rs @@ -0,0 +1,27 @@ +//@ aux-build:block-on.rs +//@ edition:2021 +//@ run-pass +//@ check-run-results + +#![feature(async_closure)] + +extern crate block_on; + +async fn call_once(f: impl async FnOnce()) { + f().await; +} + +async fn async_main() { + let x = &mut 0; + let y = &mut 0; + let c = async || { + *x = 1; + *y = 2; + }; + call_once(c).await; + println!("{x} {y}"); +} + +fn main() { + block_on::block_on(async_main()); +} diff --git a/tests/ui/async-await/async-closures/overlapping-projs.run.stdout b/tests/ui/async-await/async-closures/overlapping-projs.run.stdout new file mode 100644 index 0000000000000..8d04f961a0371 --- /dev/null +++ b/tests/ui/async-await/async-closures/overlapping-projs.run.stdout @@ -0,0 +1 @@ +1 2 diff --git a/tests/ui/async-await/async-closures/precise-captures.call.run.stdout b/tests/ui/async-await/async-closures/precise-captures.call.run.stdout new file mode 100644 index 0000000000000..6062556837c74 --- /dev/null +++ b/tests/ui/async-await/async-closures/precise-captures.call.run.stdout @@ -0,0 +1,29 @@ +after call +after await +fixed +uncaptured + +after call +after await +fixed +uncaptured + +after call +after await +fixed +uncaptured + +after call +after await +fixed +untouched + +after call +drop first +after await +uncaptured + +after call +drop first +after await +uncaptured diff --git a/tests/ui/async-await/async-closures/precise-captures.call_once.run.stdout b/tests/ui/async-await/async-closures/precise-captures.call_once.run.stdout new file mode 100644 index 0000000000000..ddb02d4760001 --- /dev/null +++ b/tests/ui/async-await/async-closures/precise-captures.call_once.run.stdout @@ -0,0 +1,29 @@ +after call +after await +fixed +uncaptured + +after call +after await +fixed +uncaptured + +after call +fixed +after await +uncaptured + +after call +after await +fixed +untouched + +after call +drop first +after await +uncaptured + +after call +drop first +after await +uncaptured diff --git a/tests/ui/async-await/async-closures/precise-captures.force_once.run.stdout b/tests/ui/async-await/async-closures/precise-captures.force_once.run.stdout new file mode 100644 index 0000000000000..ddb02d4760001 --- /dev/null +++ b/tests/ui/async-await/async-closures/precise-captures.force_once.run.stdout @@ -0,0 +1,29 @@ +after call +after await +fixed +uncaptured + +after call +after await +fixed +uncaptured + +after call +fixed +after await +uncaptured + +after call +after await +fixed +untouched + +after call +drop first +after await +uncaptured + +after call +drop first +after await +uncaptured diff --git a/tests/ui/async-await/async-closures/precise-captures.rs b/tests/ui/async-await/async-closures/precise-captures.rs new file mode 100644 index 0000000000000..e82dd1dbaf059 --- /dev/null +++ b/tests/ui/async-await/async-closures/precise-captures.rs @@ -0,0 +1,157 @@ +//@ aux-build:block-on.rs +//@ edition:2021 +//@ run-pass +//@ check-run-results +//@ revisions: call call_once force_once + +// call - Call the closure regularly. +// call_once - Call the closure w/ `async FnOnce`, so exercising the by_move shim. +// force_once - Force the closure mode to `FnOnce`, so exercising what was fixed +// in . + +#![feature(async_closure)] +#![allow(unused_mut)] + +extern crate block_on; + +#[cfg(any(call, force_once))] +macro_rules! call { + ($c:expr) => { ($c)() } +} + +#[cfg(call_once)] +async fn call_once(f: impl async FnOnce()) { + f().await +} + +#[cfg(call_once)] +macro_rules! call { + ($c:expr) => { call_once($c) } +} + +#[cfg(not(force_once))] +macro_rules! guidance { + ($c:expr) => { $c } +} + +#[cfg(force_once)] +fn infer_fnonce(c: impl async FnOnce()) -> impl async FnOnce() { c } + +#[cfg(force_once)] +macro_rules! guidance { + ($c:expr) => { infer_fnonce($c) } +} + +#[derive(Debug)] +struct Drop(&'static str); + +impl std::ops::Drop for Drop { + fn drop(&mut self) { + println!("{}", self.0); + } +} + +struct S { + a: i32, + b: Drop, + c: Drop, +} + +async fn async_main() { + // Precise capture struct + { + let mut s = S { a: 1, b: Drop("fix me up"), c: Drop("untouched") }; + let mut c = guidance!(async || { + s.a = 2; + let w = &mut s.b; + w.0 = "fixed"; + }); + s.c.0 = "uncaptured"; + let fut = call!(c); + println!("after call"); + fut.await; + println!("after await"); + } + println!(); + + // Precise capture &mut struct + { + let s = &mut S { a: 1, b: Drop("fix me up"), c: Drop("untouched") }; + let mut c = guidance!(async || { + s.a = 2; + let w = &mut s.b; + w.0 = "fixed"; + }); + s.c.0 = "uncaptured"; + let fut = call!(c); + println!("after call"); + fut.await; + println!("after await"); + } + println!(); + + // Precise capture struct by move + { + let mut s = S { a: 1, b: Drop("fix me up"), c: Drop("untouched") }; + let mut c = guidance!(async move || { + s.a = 2; + let w = &mut s.b; + w.0 = "fixed"; + }); + s.c.0 = "uncaptured"; + let fut = call!(c); + println!("after call"); + fut.await; + println!("after await"); + } + println!(); + + // Precise capture &mut struct by move + { + let s = &mut S { a: 1, b: Drop("fix me up"), c: Drop("untouched") }; + let mut c = guidance!(async move || { + s.a = 2; + let w = &mut s.b; + w.0 = "fixed"; + }); + // `s` is still captured fully as `&mut S`. + let fut = call!(c); + println!("after call"); + fut.await; + println!("after await"); + } + println!(); + + // Precise capture struct, consume field + { + let mut s = S { a: 1, b: Drop("drop first"), c: Drop("untouched") }; + let c = guidance!(async move || { + // s.a = 2; // FIXME(async_closures): Figure out why this fails + drop(s.b); + }); + s.c.0 = "uncaptured"; + let fut = call!(c); + println!("after call"); + fut.await; + println!("after await"); + } + println!(); + + // Precise capture struct by move, consume field + { + let mut s = S { a: 1, b: Drop("drop first"), c: Drop("untouched") }; + let c = guidance!(async move || { + // s.a = 2; // FIXME(async_closures): Figure out why this fails + drop(s.b); + }); + s.c.0 = "uncaptured"; + let fut = call!(c); + println!("after call"); + fut.await; + println!("after await"); + } +} + +fn main() { + block_on::block_on(async_main()); +}