Skip to content

Commit 802fb2a

Browse files
Prune unreachable variants of coroutines
1 parent 3736b85 commit 802fb2a

File tree

4 files changed

+534
-0
lines changed

4 files changed

+534
-0
lines changed

compiler/rustc_mir_transform/src/simplify.rs

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@
2727
//! naively generate still contains the `_a = ()` write in the unreachable block "after" the
2828
//! return.
2929
30+
use rustc_abi::{FieldIdx, VariantIdx};
31+
use rustc_data_structures::fx::FxHashSet;
32+
use rustc_hir::{CoroutineDesugaring, CoroutineKind};
33+
use rustc_index::bit_set::DenseBitSet;
3034
use rustc_index::{Idx, IndexSlice, IndexVec};
3135
use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
3236
use rustc_middle::mir::*;
@@ -68,6 +72,7 @@ impl SimplifyCfg {
6872

6973
pub(super) fn simplify_cfg(body: &mut Body<'_>) {
7074
CfgSimplifier::new(body).simplify();
75+
remove_dead_coroutine_switch_variants(body);
7176
remove_dead_blocks(body);
7277

7378
// FIXME: Should probably be moved into some kind of pass manager
@@ -292,6 +297,168 @@ pub(super) fn simplify_duplicate_switch_targets(terminator: &mut Terminator<'_>)
292297
}
293298
}
294299

300+
const SELF_LOCAL: Local = Local::from_u32(1);
301+
const FIELD_ZERO: FieldIdx = FieldIdx::from_u32(0);
302+
303+
pub(super) fn remove_dead_coroutine_switch_variants(body: &mut Body<'_>) {
304+
let Some(coroutine_layout) = body.coroutine_layout_raw() else {
305+
// Not a coroutine; no coroutine variants to remove.
306+
return;
307+
};
308+
309+
let bb0 = &body.basic_blocks[START_BLOCK];
310+
311+
let is_pinned = match body.coroutine_kind().unwrap() {
312+
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => false,
313+
CoroutineKind::Desugared(CoroutineDesugaring::Async, _)
314+
| CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)
315+
| CoroutineKind::Coroutine(_) => true,
316+
};
317+
// This is essentially our off-brand `Underefer`. This stores the set of locals
318+
// that we have determined to contain references to the coroutine discriminant.
319+
// If the self type is not pinned, this is just going to be `_1`. However, if
320+
// the self type is pinned, the derefer will emit statements of the form:
321+
// _x = CopyForDeref (_1.0);
322+
// We'll store the local for `_x` so that we can later detect discriminant stores
323+
// of the form:
324+
// Discriminant((*_x)) = ...
325+
// which correspond to reachable variants of the coroutine.
326+
let mut discr_locals = if is_pinned {
327+
let Some(stmt) = bb0.statements.get(0) else {
328+
// The coroutine body may have been turned into a single `unreachable`.
329+
return;
330+
};
331+
// We match `CopyForDeref` (which is what gets emitted from the state transform
332+
// pass), but also we match *regular* `Copy`, which is what GVN may optimize it to.
333+
let StatementKind::Assign(box (
334+
place,
335+
Rvalue::Use(Operand::Copy(deref_place)) | Rvalue::CopyForDeref(deref_place),
336+
)) = &stmt.kind
337+
else {
338+
panic!("The first statement of a coroutine is not a self deref");
339+
};
340+
let PlaceRef { local: SELF_LOCAL, projection: &[PlaceElem::Field(FIELD_ZERO, _)] } =
341+
deref_place.as_ref()
342+
else {
343+
panic!("The first statement of a coroutine is not a self deref");
344+
};
345+
FxHashSet::from_iter([place.as_local().unwrap()])
346+
} else {
347+
FxHashSet::from_iter([SELF_LOCAL])
348+
};
349+
350+
// The starting block of all coroutines is a switch for the coroutine variants.
351+
// This is preceded by a read of the discriminant. If we don't find this, then
352+
// we must have optimized away the switch, so bail.
353+
let StatementKind::Assign(box (discr_place, Rvalue::Discriminant(discr_local))) =
354+
&bb0.statements[if is_pinned { 1 } else { 0 }].kind
355+
else {
356+
// The following statement is not a discriminant read. We may have
357+
// optimized it out, so bail gracefully.
358+
return;
359+
};
360+
let PlaceRef { local: deref_local, projection: &[PlaceElem::Deref] } = (*discr_local).as_ref()
361+
else {
362+
// We expect the discriminant to have read `&mut self`,
363+
// so we expect the place to be a deref. If we didn't, then
364+
// it may have been optimized out, so bail gracefully.
365+
return;
366+
};
367+
if !discr_locals.contains(&deref_local) {
368+
// The place being read isn't `_1` (self) or a `Derefer`-inserted local.
369+
// It may have been optimized out, so bail gracefully.
370+
return;
371+
}
372+
let TerminatorKind::SwitchInt { discr: Operand::Move(place), targets } = &bb0.terminator().kind
373+
else {
374+
// When panic=abort, we may end up folding away the other variants of the
375+
// coroutine, and end up with ths `SwitchInt` getting replaced. In this
376+
// case, there's no need to do this optimization, so bail gracefully.
377+
return;
378+
};
379+
if place != discr_place {
380+
// Make sure we don't try to match on some other `SwitchInt`; we should be
381+
// matching on the discriminant we just read.
382+
return;
383+
}
384+
385+
let mut visited = DenseBitSet::new_empty(body.basic_blocks.len());
386+
let mut worklist = vec![];
387+
let mut visited_variants = DenseBitSet::new_empty(coroutine_layout.variant_fields.len());
388+
389+
// Insert unresumed (initial), returned, panicked variants.
390+
// We treat these as always reachable.
391+
visited_variants.insert(VariantIdx::from_usize(0));
392+
visited_variants.insert(VariantIdx::from_usize(1));
393+
visited_variants.insert(VariantIdx::from_usize(2));
394+
worklist.push(targets.target_for_value(0));
395+
worklist.push(targets.target_for_value(1));
396+
worklist.push(targets.target_for_value(2));
397+
398+
// Walk all of the reachable variant blocks.
399+
while let Some(block) = worklist.pop() {
400+
if !visited.insert(block) {
401+
continue;
402+
}
403+
404+
let data = &body.basic_blocks[block];
405+
for stmt in &data.statements {
406+
match &stmt.kind {
407+
// If we see a `SetDiscriminant` statement for our coroutine,
408+
// mark that variant as reachable and add it to the worklist.
409+
StatementKind::SetDiscriminant { place, variant_index } => {
410+
let PlaceRef { local: deref_local, projection: &[PlaceElem::Deref] } =
411+
(**place).as_ref()
412+
else {
413+
continue;
414+
};
415+
if !discr_locals.contains(&deref_local) {
416+
continue;
417+
}
418+
visited_variants.insert(*variant_index);
419+
worklist.push(targets.target_for_value(variant_index.as_u32().into()));
420+
}
421+
// The derefer may have inserted a local to access the variant.
422+
// Make sure we keep track of it here.
423+
StatementKind::Assign(box (place, Rvalue::CopyForDeref(deref_place))) => {
424+
if !is_pinned {
425+
continue;
426+
}
427+
let PlaceRef {
428+
local: SELF_LOCAL,
429+
projection: &[PlaceElem::Field(FIELD_ZERO, _)],
430+
} = deref_place.as_ref()
431+
else {
432+
continue;
433+
};
434+
discr_locals.insert(place.as_local().unwrap());
435+
}
436+
_ => {}
437+
}
438+
}
439+
440+
// Also walk all the successors of this block.
441+
if let Some(term) = &data.terminator {
442+
worklist.extend(term.successors());
443+
}
444+
}
445+
446+
// Filter out the variants that are unreachable.
447+
let TerminatorKind::SwitchInt { targets, .. } =
448+
&mut body.basic_blocks.as_mut()[START_BLOCK].terminator_mut().kind
449+
else {
450+
unreachable!();
451+
};
452+
*targets = SwitchTargets::new(
453+
targets
454+
.iter()
455+
.filter(|(idx, _)| visited_variants.contains(VariantIdx::from_u32(*idx as u32))),
456+
targets.otherwise(),
457+
);
458+
459+
// FIXME: We could remove dead variant fields from the coroutine layout, too.
460+
}
461+
295462
pub(super) fn remove_dead_blocks(body: &mut Body<'_>) {
296463
let should_deduplicate_unreachable = |bbdata: &BasicBlockData<'_>| {
297464
// CfgSimplifier::simplify leaves behind some unreachable basic blocks without a
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
- // MIR for `outer::{closure#0}` before SimplifyCfg-final
2+
+ // MIR for `outer::{closure#0}` after SimplifyCfg-final
3+
/* coroutine_layout = CoroutineLayout {
4+
field_tys: {
5+
_0: CoroutineSavedTy {
6+
ty: Coroutine(
7+
DefId(0:5 ~ coroutine_dead_variants[61c4]::inner::{closure#0}),
8+
[
9+
(),
10+
std::future::ResumeTy,
11+
(),
12+
(),
13+
CoroutineWitness(
14+
DefId(0:5 ~ coroutine_dead_variants[61c4]::inner::{closure#0}),
15+
[],
16+
),
17+
(),
18+
],
19+
),
20+
source_info: SourceInfo {
21+
span: $DIR/coroutine_dead_variants.rs:13:9: 13:22 (#16),
22+
scope: scope[0],
23+
},
24+
ignore_for_traits: false,
25+
},
26+
},
27+
variant_fields: {
28+
Unresumed(0): [],
29+
Returned (1): [],
30+
Panicked (2): [],
31+
Suspend0 (3): [_0],
32+
},
33+
storage_conflicts: BitMatrix(1x1) {
34+
(_0, _0),
35+
},
36+
} */
37+
38+
fn outer::{closure#0}(_1: Pin<&mut {async fn body of outer()}>, _2: &mut Context<'_>) -> Poll<()> {
39+
debug _task_context => _2;
40+
let mut _0: std::task::Poll<()>;
41+
let mut _3: {async fn body of inner()};
42+
let mut _4: {async fn body of inner()};
43+
let mut _5: std::task::Poll<()>;
44+
let mut _6: std::pin::Pin<&mut {async fn body of inner()}>;
45+
let mut _7: &mut {async fn body of inner()};
46+
let mut _8: &mut std::task::Context<'_>;
47+
let mut _9: isize;
48+
let mut _11: ();
49+
let mut _12: &mut std::task::Context<'_>;
50+
let mut _13: u32;
51+
let mut _14: &mut {async fn body of outer()};
52+
scope 1 {
53+
debug __awaitee => (((*(_1.0: &mut {async fn body of outer()})) as variant#3).0: {async fn body of inner()});
54+
let _10: ();
55+
scope 2 {
56+
debug result => const ();
57+
}
58+
}
59+
60+
bb0: {
61+
_14 = copy (_1.0: &mut {async fn body of outer()});
62+
_13 = discriminant((*_14));
63+
- switchInt(move _13) -> [0: bb1, 1: bb15, 3: bb14, otherwise: bb8];
64+
+ switchInt(move _13) -> [0: bb2, 1: bb4, otherwise: bb1];
65+
}
66+
67+
bb1: {
68+
- nop;
69+
- goto -> bb12;
70+
- }
71+
-
72+
- bb2: {
73+
- StorageLive(_3);
74+
- StorageLive(_4);
75+
- _4 = inner() -> [return: bb3, unwind unreachable];
76+
- }
77+
-
78+
- bb3: {
79+
- _3 = <{async fn body of inner()} as IntoFuture>::into_future(move _4) -> [return: bb4, unwind unreachable];
80+
- }
81+
-
82+
- bb4: {
83+
- StorageDead(_4);
84+
- (((*_14) as variant#3).0: {async fn body of inner()}) = move _3;
85+
- goto -> bb5;
86+
- }
87+
-
88+
- bb5: {
89+
- StorageLive(_5);
90+
- StorageLive(_6);
91+
- _7 = &mut (((*_14) as variant#3).0: {async fn body of inner()});
92+
- _6 = Pin::<&mut {async fn body of inner()}>::new_unchecked(copy _7) -> [return: bb6, unwind unreachable];
93+
- }
94+
-
95+
- bb6: {
96+
- nop;
97+
- _5 = <{async fn body of inner()} as Future>::poll(move _6, copy _2) -> [return: bb7, unwind unreachable];
98+
- }
99+
-
100+
- bb7: {
101+
- StorageDead(_6);
102+
- _9 = discriminant(_5);
103+
- switchInt(move _9) -> [0: bb10, 1: bb9, otherwise: bb8];
104+
- }
105+
-
106+
- bb8: {
107+
unreachable;
108+
}
109+
110+
- bb9: {
111+
- StorageDead(_5);
112+
- _0 = const Poll::<()>::Pending;
113+
- StorageDead(_3);
114+
- discriminant((*_14)) = 3;
115+
- return;
116+
- }
117+
-
118+
- bb10: {
119+
- StorageLive(_10);
120+
- nop;
121+
- StorageDead(_10);
122+
- StorageDead(_5);
123+
- drop((((*_14) as variant#3).0: {async fn body of inner()})) -> [return: bb11, unwind unreachable];
124+
- }
125+
-
126+
- bb11: {
127+
- StorageDead(_3);
128+
+ bb2: {
129+
_11 = const ();
130+
- goto -> bb13;
131+
+ goto -> bb3;
132+
}
133+
134+
- bb12: {
135+
- _11 = const ();
136+
- goto -> bb13;
137+
- }
138+
-
139+
- bb13: {
140+
+ bb3: {
141+
_0 = Poll::<()>::Ready(const ());
142+
discriminant((*_14)) = 1;
143+
return;
144+
}
145+
146+
- bb14: {
147+
- StorageLive(_3);
148+
- nop;
149+
- goto -> bb5;
150+
- }
151+
-
152+
- bb15: {
153+
- assert(const false, "`async fn` resumed after completion") -> [success: bb15, unwind unreachable];
154+
+ bb4: {
155+
+ assert(const false, "`async fn` resumed after completion") -> [success: bb4, unwind unreachable];
156+
}
157+
}
158+

0 commit comments

Comments
 (0)