Skip to content

Commit 376dd78

Browse files
Make sure that async closures (and fns) only capture their parent callable's parameters by move, and nothing else
1 parent 15c8c3a commit 376dd78

File tree

2 files changed

+56
-4
lines changed

2 files changed

+56
-4
lines changed

compiler/rustc_ast_lowering/src/item.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,10 +1281,11 @@ impl<'hir> LoweringContext<'_, 'hir> {
12811281
};
12821282
let closure_id = coroutine_kind.closure_id();
12831283
let coroutine_expr = self.make_desugared_coroutine_expr(
1284-
// FIXME(async_closures): This should only move locals,
1285-
// and not upvars. Capturing closure upvars by ref doesn't
1286-
// work right now anyways, so whatever.
1287-
CaptureBy::Value { move_kw: rustc_span::DUMMY_SP },
1284+
// The default capture mode here is by-ref. Later on during upvar analysis,
1285+
// we will force the captured arguments to by-move, but for async closures,
1286+
// we want to make sure that we avoid unnecessarily moving captures, or else
1287+
// all async closures would default to `FnOnce` as their calling mode.
1288+
CaptureBy::Ref,
12881289
closure_id,
12891290
return_type_hint,
12901291
body_span,

compiler/rustc_hir_typeck/src/upvar.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,57 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
200200
capture_information: Default::default(),
201201
fake_reads: Default::default(),
202202
};
203+
204+
// As noted in `lower_coroutine_body_with_moved_arguments`, we default the capture mode
205+
// to `ByRef` for the `async {}` block internal to async fns/closure. This means
206+
// that we would *not* be moving all of the parameters into the async block by default.
207+
//
208+
// We force all of these arguments to be captured by move before we do expr use analysis.
209+
//
210+
// FIXME(async_closures): This could be cleaned up. It's a bit janky that we're just
211+
// moving all of the `LocalSource::AsyncFn` locals here.
212+
if let Some(hir::CoroutineKind::Desugared(
213+
_,
214+
hir::CoroutineSource::Fn | hir::CoroutineSource::Closure,
215+
)) = self.tcx.coroutine_kind(closure_def_id)
216+
{
217+
let hir::ExprKind::Block(block, _) = body.value.kind else {
218+
bug!();
219+
};
220+
for stmt in block.stmts {
221+
let hir::StmtKind::Local(hir::Local {
222+
init: Some(init),
223+
source: hir::LocalSource::AsyncFn,
224+
pat,
225+
..
226+
}) = stmt.kind
227+
else {
228+
bug!();
229+
};
230+
let hir::PatKind::Binding(hir::BindingAnnotation(hir::ByRef::No, _), _, _, _) =
231+
pat.kind
232+
else {
233+
// Complex pattern, skip the non-upvar local.
234+
continue;
235+
};
236+
let hir::ExprKind::Path(hir::QPath::Resolved(_, path)) = init.kind else {
237+
bug!();
238+
};
239+
let hir::def::Res::Local(local_id) = path.res else {
240+
bug!();
241+
};
242+
let place = self.place_for_root_variable(closure_def_id, local_id);
243+
delegate.capture_information.push((
244+
place,
245+
ty::CaptureInfo {
246+
capture_kind_expr_id: Some(init.hir_id),
247+
path_expr_id: Some(init.hir_id),
248+
capture_kind: UpvarCapture::ByValue,
249+
},
250+
));
251+
}
252+
}
253+
203254
euv::ExprUseVisitor::new(
204255
&mut delegate,
205256
&self.infcx,

0 commit comments

Comments
 (0)