Skip to content
This repository was archived by the owner on May 28, 2025. It is now read-only.

Commit 3674032

Browse files
Rework the ByMoveBody shim to actually work correctly
1 parent 1921968 commit 3674032

File tree

5 files changed

+334
-34
lines changed

5 files changed

+334
-34
lines changed

compiler/rustc_mir_transform/src/coroutine/by_move_body.rs

Lines changed: 90 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,13 @@
6060
//! with it in lockstep. When we need to resolve a body for `FnOnce` or `AsyncFnOnce`,
6161
//! we use this "by move" body instead.
6262
63-
use itertools::Itertools;
64-
65-
use rustc_data_structures::unord::UnordSet;
63+
use rustc_data_structures::unord::UnordMap;
6664
use rustc_hir as hir;
65+
use rustc_middle::hir::place::{Projection, ProjectionKind};
6766
use rustc_middle::mir::visit::MutVisitor;
6867
use rustc_middle::mir::{self, dump_mir, MirPass};
6968
use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt, TypeVisitableExt};
70-
use rustc_target::abi::FieldIdx;
69+
use rustc_target::abi::{FieldIdx, VariantIdx};
7170

7271
pub struct ByMoveBody;
7372

@@ -116,32 +115,76 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
116115
.tuple_fields()
117116
.len();
118117

119-
let mut by_ref_fields = UnordSet::default();
120-
for (idx, (coroutine_capture, parent_capture)) in tcx
118+
let mut field_remapping = UnordMap::default();
119+
120+
let mut parent_captures =
121+
tcx.closure_captures(parent_def_id).iter().copied().enumerate().peekable();
122+
123+
for (child_field_idx, child_capture) in tcx
121124
.closure_captures(coroutine_def_id)
122125
.iter()
126+
.copied()
123127
// By construction we capture all the args first.
124128
.skip(num_args)
125-
.zip_eq(tcx.closure_captures(parent_def_id))
126129
.enumerate()
127130
{
128-
// This upvar is captured by-move from the parent closure, but by-ref
129-
// from the inner async block. That means that it's being borrowed from
130-
// the outer closure body -- we need to change the coroutine to take the
131-
// upvar by value.
132-
if coroutine_capture.is_by_ref() && !parent_capture.is_by_ref() {
133-
assert_ne!(
134-
coroutine_kind,
135-
ty::ClosureKind::FnOnce,
136-
"`FnOnce` coroutine-closures return coroutines that capture from \
137-
their body; it will always result in a borrowck error!"
131+
loop {
132+
let Some(&(parent_field_idx, parent_capture)) = parent_captures.peek() else {
133+
bug!("we ran out of parent captures!")
134+
};
135+
136+
if !std::iter::zip(
137+
&child_capture.place.projections,
138+
&parent_capture.place.projections,
139+
)
140+
.all(|(child, parent)| child.kind == parent.kind)
141+
{
142+
// Skip this field.
143+
let _ = parent_captures.next().unwrap();
144+
continue;
145+
}
146+
147+
let child_precise_captures =
148+
&child_capture.place.projections[parent_capture.place.projections.len()..];
149+
150+
let needs_deref = child_capture.is_by_ref() && !parent_capture.is_by_ref();
151+
if needs_deref {
152+
assert_ne!(
153+
coroutine_kind,
154+
ty::ClosureKind::FnOnce,
155+
"`FnOnce` coroutine-closures return coroutines that capture from \
156+
their body; it will always result in a borrowck error!"
157+
);
158+
}
159+
160+
let mut parent_capture_ty = parent_capture.place.ty();
161+
parent_capture_ty = match parent_capture.info.capture_kind {
162+
ty::UpvarCapture::ByValue => parent_capture_ty,
163+
ty::UpvarCapture::ByRef(kind) => Ty::new_ref(
164+
tcx,
165+
tcx.lifetimes.re_erased,
166+
parent_capture_ty,
167+
kind.to_mutbl_lossy(),
168+
),
169+
};
170+
171+
field_remapping.insert(
172+
FieldIdx::from_usize(child_field_idx + num_args),
173+
(
174+
FieldIdx::from_usize(parent_field_idx + num_args),
175+
parent_capture_ty,
176+
needs_deref,
177+
child_precise_captures,
178+
),
138179
);
139-
by_ref_fields.insert(FieldIdx::from_usize(num_args + idx));
180+
181+
break;
140182
}
183+
}
141184

142-
// Make sure we're actually talking about the same capture.
143-
// FIXME(async_closures): We could look at the `hir::Upvar` instead?
144-
assert_eq!(coroutine_capture.place.ty(), parent_capture.place.ty());
185+
if coroutine_kind == ty::ClosureKind::FnOnce {
186+
assert_eq!(field_remapping.len(), tcx.closure_captures(parent_def_id).len());
187+
return;
145188
}
146189

147190
let by_move_coroutine_ty = tcx
@@ -157,7 +200,7 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
157200
);
158201

159202
let mut by_move_body = body.clone();
160-
MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body);
203+
MakeByMoveBody { tcx, field_remapping, by_move_coroutine_ty }.visit_body(&mut by_move_body);
161204
dump_mir(tcx, false, "coroutine_by_move", &0, &by_move_body, |_, _| Ok(()));
162205
by_move_body.source = mir::MirSource::from_instance(InstanceDef::CoroutineKindShim {
163206
coroutine_def_id: coroutine_def_id.to_def_id(),
@@ -168,7 +211,7 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
168211

169212
struct MakeByMoveBody<'tcx> {
170213
tcx: TyCtxt<'tcx>,
171-
by_ref_fields: UnordSet<FieldIdx>,
214+
field_remapping: UnordMap<FieldIdx, (FieldIdx, Ty<'tcx>, bool, &'tcx [Projection<'tcx>])>,
172215
by_move_coroutine_ty: Ty<'tcx>,
173216
}
174217

@@ -184,23 +227,36 @@ impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> {
184227
location: mir::Location,
185228
) {
186229
if place.local == ty::CAPTURE_STRUCT_LOCAL
187-
&& let Some((&mir::ProjectionElem::Field(idx, ty), projection)) =
230+
&& let Some((&mir::ProjectionElem::Field(idx, _), projection)) =
188231
place.projection.split_first()
189-
&& self.by_ref_fields.contains(&idx)
232+
&& let Some(&(remapped_idx, remapped_ty, needs_deref, additional_projections)) =
233+
self.field_remapping.get(&idx)
190234
{
191-
let (begin, end) = projection.split_first().unwrap();
192-
// FIXME(async_closures): I'm actually a bit surprised to see that we always
193-
// initially deref the by-ref upvars. If this is not actually true, then we
194-
// will at least get an ICE that explains why this isn't true :^)
195-
assert_eq!(*begin, mir::ProjectionElem::Deref);
196-
// Peel one ref off of the ty.
197-
let peeled_ty = ty.builtin_deref(true).unwrap().ty;
235+
let final_deref = if needs_deref {
236+
let Some((mir::ProjectionElem::Deref, rest)) = projection.split_first() else {
237+
bug!();
238+
};
239+
rest
240+
} else {
241+
projection
242+
};
243+
244+
let additional_projections =
245+
additional_projections.iter().map(|elem| match elem.kind {
246+
ProjectionKind::Deref => mir::ProjectionElem::Deref,
247+
ProjectionKind::Field(idx, VariantIdx::ZERO) => {
248+
mir::ProjectionElem::Field(idx, elem.ty)
249+
}
250+
_ => unreachable!("precise captures only through fields and derefs"),
251+
});
252+
198253
*place = mir::Place {
199254
local: place.local,
200255
projection: self.tcx.mk_place_elems_from_iter(
201-
[mir::ProjectionElem::Field(idx, peeled_ty)]
256+
[mir::ProjectionElem::Field(remapped_idx, remapped_ty)]
202257
.into_iter()
203-
.chain(end.iter().copied()),
258+
.chain(additional_projections)
259+
.chain(final_deref.iter().copied()),
204260
),
205261
};
206262
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
after call
2+
after await
3+
fixed
4+
uncaptured
5+
6+
after call
7+
after await
8+
fixed
9+
uncaptured
10+
11+
after call
12+
after await
13+
fixed
14+
uncaptured
15+
16+
after call
17+
after await
18+
fixed
19+
untouched
20+
21+
after call
22+
drop first
23+
after await
24+
uncaptured
25+
26+
after call
27+
drop first
28+
after await
29+
uncaptured
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
after call
2+
after await
3+
fixed
4+
uncaptured
5+
6+
after call
7+
after await
8+
fixed
9+
uncaptured
10+
11+
after call
12+
fixed
13+
after await
14+
uncaptured
15+
16+
after call
17+
after await
18+
fixed
19+
untouched
20+
21+
after call
22+
drop first
23+
after await
24+
uncaptured
25+
26+
after call
27+
drop first
28+
after await
29+
uncaptured
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
after call
2+
after await
3+
fixed
4+
uncaptured
5+
6+
after call
7+
after await
8+
fixed
9+
uncaptured
10+
11+
after call
12+
fixed
13+
after await
14+
uncaptured
15+
16+
after call
17+
after await
18+
fixed
19+
untouched
20+
21+
after call
22+
drop first
23+
after await
24+
uncaptured
25+
26+
after call
27+
drop first
28+
after await
29+
uncaptured

0 commit comments

Comments
 (0)