60
60
//! with it in lockstep. When we need to resolve a body for `FnOnce` or `AsyncFnOnce`,
61
61
//! we use this "by move" body instead.
62
62
63
- use itertools:: Itertools ;
64
-
65
- use rustc_data_structures:: unord:: UnordSet ;
63
+ use rustc_data_structures:: unord:: UnordMap ;
66
64
use rustc_hir as hir;
65
+ use rustc_middle:: hir:: place:: { Projection , ProjectionKind } ;
67
66
use rustc_middle:: mir:: visit:: MutVisitor ;
68
67
use rustc_middle:: mir:: { self , dump_mir, MirPass } ;
69
68
use rustc_middle:: ty:: { self , InstanceDef , Ty , TyCtxt , TypeVisitableExt } ;
70
- use rustc_target:: abi:: FieldIdx ;
69
+ use rustc_target:: abi:: { FieldIdx , VariantIdx } ;
71
70
72
71
pub struct ByMoveBody ;
73
72
@@ -116,32 +115,76 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
116
115
. tuple_fields ( )
117
116
. len ( ) ;
118
117
119
- let mut by_ref_fields = UnordSet :: default ( ) ;
120
- for ( idx, ( coroutine_capture, parent_capture) ) in tcx
118
+ let mut field_remapping = UnordMap :: default ( ) ;
119
+
120
+ let mut parent_captures =
121
+ tcx. closure_captures ( parent_def_id) . iter ( ) . copied ( ) . enumerate ( ) . peekable ( ) ;
122
+
123
+ for ( child_field_idx, child_capture) in tcx
121
124
. closure_captures ( coroutine_def_id)
122
125
. iter ( )
126
+ . copied ( )
123
127
// By construction we capture all the args first.
124
128
. skip ( num_args)
125
- . zip_eq ( tcx. closure_captures ( parent_def_id) )
126
129
. enumerate ( )
127
130
{
128
- // This upvar is captured by-move from the parent closure, but by-ref
129
- // from the inner async block. That means that it's being borrowed from
130
- // the outer closure body -- we need to change the coroutine to take the
131
- // upvar by value.
132
- if coroutine_capture. is_by_ref ( ) && !parent_capture. is_by_ref ( ) {
133
- assert_ne ! (
134
- coroutine_kind,
135
- ty:: ClosureKind :: FnOnce ,
136
- "`FnOnce` coroutine-closures return coroutines that capture from \
137
- their body; it will always result in a borrowck error!"
131
+ loop {
132
+ let Some ( & ( parent_field_idx, parent_capture) ) = parent_captures. peek ( ) else {
133
+ bug ! ( "we ran out of parent captures!" )
134
+ } ;
135
+
136
+ if !std:: iter:: zip (
137
+ & child_capture. place . projections ,
138
+ & parent_capture. place . projections ,
139
+ )
140
+ . all ( |( child, parent) | child. kind == parent. kind )
141
+ {
142
+ // Skip this field.
143
+ let _ = parent_captures. next ( ) . unwrap ( ) ;
144
+ continue ;
145
+ }
146
+
147
+ let child_precise_captures =
148
+ & child_capture. place . projections [ parent_capture. place . projections . len ( ) ..] ;
149
+
150
+ let needs_deref = child_capture. is_by_ref ( ) && !parent_capture. is_by_ref ( ) ;
151
+ if needs_deref {
152
+ assert_ne ! (
153
+ coroutine_kind,
154
+ ty:: ClosureKind :: FnOnce ,
155
+ "`FnOnce` coroutine-closures return coroutines that capture from \
156
+ their body; it will always result in a borrowck error!"
157
+ ) ;
158
+ }
159
+
160
+ let mut parent_capture_ty = parent_capture. place . ty ( ) ;
161
+ parent_capture_ty = match parent_capture. info . capture_kind {
162
+ ty:: UpvarCapture :: ByValue => parent_capture_ty,
163
+ ty:: UpvarCapture :: ByRef ( kind) => Ty :: new_ref (
164
+ tcx,
165
+ tcx. lifetimes . re_erased ,
166
+ parent_capture_ty,
167
+ kind. to_mutbl_lossy ( ) ,
168
+ ) ,
169
+ } ;
170
+
171
+ field_remapping. insert (
172
+ FieldIdx :: from_usize ( child_field_idx + num_args) ,
173
+ (
174
+ FieldIdx :: from_usize ( parent_field_idx + num_args) ,
175
+ parent_capture_ty,
176
+ needs_deref,
177
+ child_precise_captures,
178
+ ) ,
138
179
) ;
139
- by_ref_fields. insert ( FieldIdx :: from_usize ( num_args + idx) ) ;
180
+
181
+ break ;
140
182
}
183
+ }
141
184
142
- // Make sure we're actually talking about the same capture.
143
- // FIXME(async_closures): We could look at the `hir::Upvar` instead?
144
- assert_eq ! ( coroutine_capture . place . ty ( ) , parent_capture . place . ty ( ) ) ;
185
+ if coroutine_kind == ty :: ClosureKind :: FnOnce {
186
+ assert_eq ! ( field_remapping . len ( ) , tcx . closure_captures ( parent_def_id ) . len ( ) ) ;
187
+ return ;
145
188
}
146
189
147
190
let by_move_coroutine_ty = tcx
@@ -157,7 +200,7 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
157
200
) ;
158
201
159
202
let mut by_move_body = body. clone ( ) ;
160
- MakeByMoveBody { tcx, by_ref_fields , by_move_coroutine_ty } . visit_body ( & mut by_move_body) ;
203
+ MakeByMoveBody { tcx, field_remapping , by_move_coroutine_ty } . visit_body ( & mut by_move_body) ;
161
204
dump_mir ( tcx, false , "coroutine_by_move" , & 0 , & by_move_body, |_, _| Ok ( ( ) ) ) ;
162
205
by_move_body. source = mir:: MirSource :: from_instance ( InstanceDef :: CoroutineKindShim {
163
206
coroutine_def_id : coroutine_def_id. to_def_id ( ) ,
@@ -168,7 +211,7 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
168
211
169
212
struct MakeByMoveBody < ' tcx > {
170
213
tcx : TyCtxt < ' tcx > ,
171
- by_ref_fields : UnordSet < FieldIdx > ,
214
+ field_remapping : UnordMap < FieldIdx , ( FieldIdx , Ty < ' tcx > , bool , & ' tcx [ Projection < ' tcx > ] ) > ,
172
215
by_move_coroutine_ty : Ty < ' tcx > ,
173
216
}
174
217
@@ -184,23 +227,36 @@ impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> {
184
227
location : mir:: Location ,
185
228
) {
186
229
if place. local == ty:: CAPTURE_STRUCT_LOCAL
187
- && let Some ( ( & mir:: ProjectionElem :: Field ( idx, ty ) , projection) ) =
230
+ && let Some ( ( & mir:: ProjectionElem :: Field ( idx, _ ) , projection) ) =
188
231
place. projection . split_first ( )
189
- && self . by_ref_fields . contains ( & idx)
232
+ && let Some ( & ( remapped_idx, remapped_ty, needs_deref, additional_projections) ) =
233
+ self . field_remapping . get ( & idx)
190
234
{
191
- let ( begin, end) = projection. split_first ( ) . unwrap ( ) ;
192
- // FIXME(async_closures): I'm actually a bit surprised to see that we always
193
- // initially deref the by-ref upvars. If this is not actually true, then we
194
- // will at least get an ICE that explains why this isn't true :^)
195
- assert_eq ! ( * begin, mir:: ProjectionElem :: Deref ) ;
196
- // Peel one ref off of the ty.
197
- let peeled_ty = ty. builtin_deref ( true ) . unwrap ( ) . ty ;
235
+ let final_deref = if needs_deref {
236
+ let Some ( ( mir:: ProjectionElem :: Deref , rest) ) = projection. split_first ( ) else {
237
+ bug ! ( ) ;
238
+ } ;
239
+ rest
240
+ } else {
241
+ projection
242
+ } ;
243
+
244
+ let additional_projections =
245
+ additional_projections. iter ( ) . map ( |elem| match elem. kind {
246
+ ProjectionKind :: Deref => mir:: ProjectionElem :: Deref ,
247
+ ProjectionKind :: Field ( idx, VariantIdx :: ZERO ) => {
248
+ mir:: ProjectionElem :: Field ( idx, elem. ty )
249
+ }
250
+ _ => unreachable ! ( "precise captures only through fields and derefs" ) ,
251
+ } ) ;
252
+
198
253
* place = mir:: Place {
199
254
local : place. local ,
200
255
projection : self . tcx . mk_place_elems_from_iter (
201
- [ mir:: ProjectionElem :: Field ( idx , peeled_ty ) ]
256
+ [ mir:: ProjectionElem :: Field ( remapped_idx , remapped_ty ) ]
202
257
. into_iter ( )
203
- . chain ( end. iter ( ) . copied ( ) ) ,
258
+ . chain ( additional_projections)
259
+ . chain ( final_deref. iter ( ) . copied ( ) ) ,
204
260
) ,
205
261
} ;
206
262
}
0 commit comments