From 802fb2acf1766fa9db5d864efa65c85919a18998 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Tue, 14 Jan 2025 06:04:19 +0000 Subject: [PATCH] Prune unreachable variants of coroutines --- compiler/rustc_mir_transform/src/simplify.rs | 167 +++++++++++++++ ...sure#0}.SimplifyCfg-final.panic-abort.diff | 158 ++++++++++++++ ...ure#0}.SimplifyCfg-final.panic-unwind.diff | 194 ++++++++++++++++++ tests/mir-opt/coroutine_dead_variants.rs | 15 ++ 4 files changed, 534 insertions(+) create mode 100644 tests/mir-opt/coroutine_dead_variants.outer-{closure#0}.SimplifyCfg-final.panic-abort.diff create mode 100644 tests/mir-opt/coroutine_dead_variants.outer-{closure#0}.SimplifyCfg-final.panic-unwind.diff create mode 100644 tests/mir-opt/coroutine_dead_variants.rs diff --git a/compiler/rustc_mir_transform/src/simplify.rs b/compiler/rustc_mir_transform/src/simplify.rs index 4f312ed2aaabc..e210edc619f23 100644 --- a/compiler/rustc_mir_transform/src/simplify.rs +++ b/compiler/rustc_mir_transform/src/simplify.rs @@ -27,6 +27,10 @@ //! naively generate still contains the `_a = ()` write in the unreachable block "after" the //! return. 
+use rustc_abi::{FieldIdx, VariantIdx}; +use rustc_data_structures::fx::FxHashSet; +use rustc_hir::{CoroutineDesugaring, CoroutineKind}; +use rustc_index::bit_set::DenseBitSet; use rustc_index::{Idx, IndexSlice, IndexVec}; use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor}; use rustc_middle::mir::*; @@ -68,6 +72,7 @@ impl SimplifyCfg { pub(super) fn simplify_cfg(body: &mut Body<'_>) { CfgSimplifier::new(body).simplify(); + remove_dead_coroutine_switch_variants(body); remove_dead_blocks(body); // FIXME: Should probably be moved into some kind of pass manager @@ -292,6 +297,168 @@ pub(super) fn simplify_duplicate_switch_targets(terminator: &mut Terminator<'_>) } } +const SELF_LOCAL: Local = Local::from_u32(1); +const FIELD_ZERO: FieldIdx = FieldIdx::from_u32(0); + +pub(super) fn remove_dead_coroutine_switch_variants(body: &mut Body<'_>) { + let Some(coroutine_layout) = body.coroutine_layout_raw() else { + // Not a coroutine; no coroutine variants to remove. + return; + }; + + let bb0 = &body.basic_blocks[START_BLOCK]; + + let is_pinned = match body.coroutine_kind().unwrap() { + CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => false, + CoroutineKind::Desugared(CoroutineDesugaring::Async, _) + | CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) + | CoroutineKind::Coroutine(_) => true, + }; + // This is essentially our off-brand `Underefer`. This stores the set of locals + // that we have determined to contain references to the coroutine discriminant. + // If the self type is not pinned, this is just going to be `_1`. However, if + // the self type is pinned, the derefer will emit statements of the form: + // _x = CopyForDeref (_1.0); + // We'll store the local for `_x` so that we can later detect discriminant stores + // of the form: + // Discriminant((*_x)) = ... + // which correspond to reachable variants of the coroutine. 
+ let mut discr_locals = if is_pinned { + let Some(stmt) = bb0.statements.get(0) else { + // The coroutine body may have been turned into a single `unreachable`. + return; + }; + // We match `CopyForDeref` (which is what gets emitted from the state transform + // pass), but also we match *regular* `Copy`, which is what GVN may optimize it to. + let StatementKind::Assign(box ( + place, + Rvalue::Use(Operand::Copy(deref_place)) | Rvalue::CopyForDeref(deref_place), + )) = &stmt.kind + else { + panic!("The first statement of a coroutine is not a self deref"); + }; + let PlaceRef { local: SELF_LOCAL, projection: &[PlaceElem::Field(FIELD_ZERO, _)] } = + deref_place.as_ref() + else { + panic!("The first statement of a coroutine is not a self deref"); + }; + FxHashSet::from_iter([place.as_local().unwrap()]) + } else { + FxHashSet::from_iter([SELF_LOCAL]) + }; + + // The starting block of all coroutines is a switch for the coroutine variants. + // This is preceded by a read of the discriminant. If we don't find this, then + // we must have optimized away the switch, so bail. + let Some(StatementKind::Assign(box (discr_place, Rvalue::Discriminant(discr_local)))) = + bb0.statements.get(if is_pinned { 1 } else { 0 }).map(|stmt| &stmt.kind) + else { + // The following statement is missing or is not a discriminant read. We may + // have optimized it out, so bail gracefully. + return; + }; + let PlaceRef { local: deref_local, projection: &[PlaceElem::Deref] } = (*discr_local).as_ref() + else { + // We expect the discriminant to have read `&mut self`, + // so we expect the place to be a deref. If we didn't, then + // it may have been optimized out, so bail gracefully. + return; + }; + if !discr_locals.contains(&deref_local) { + // The place being read isn't `_1` (self) or a `Derefer`-inserted local. + // It may have been optimized out, so bail gracefully. 
+ return; + } + let TerminatorKind::SwitchInt { discr: Operand::Move(place), targets } = &bb0.terminator().kind + else { + // When panic=abort, we may end up folding away the other variants of the + // coroutine, and end up with this `SwitchInt` getting replaced. In this + // case, there's no need to do this optimization, so bail gracefully. + return; + }; + if place != discr_place { + // Make sure we don't try to match on some other `SwitchInt`; we should be + // matching on the discriminant we just read. + return; + } + + let mut visited = DenseBitSet::new_empty(body.basic_blocks.len()); + let mut worklist = vec![]; + let mut visited_variants = DenseBitSet::new_empty(coroutine_layout.variant_fields.len()); + + // Insert unresumed (initial), returned, panicked variants. + // We treat these as always reachable. + visited_variants.insert(VariantIdx::from_usize(0)); + visited_variants.insert(VariantIdx::from_usize(1)); + visited_variants.insert(VariantIdx::from_usize(2)); + worklist.push(targets.target_for_value(0)); + worklist.push(targets.target_for_value(1)); + worklist.push(targets.target_for_value(2)); + + // Walk all of the reachable variant blocks. + while let Some(block) = worklist.pop() { + if !visited.insert(block) { + continue; + } + + let data = &body.basic_blocks[block]; + for stmt in &data.statements { + match &stmt.kind { + // If we see a `SetDiscriminant` statement for our coroutine, + // mark that variant as reachable and add it to the worklist. + StatementKind::SetDiscriminant { place, variant_index } => { + let PlaceRef { local: deref_local, projection: &[PlaceElem::Deref] } = + (**place).as_ref() + else { + continue; + }; + if !discr_locals.contains(&deref_local) { + continue; + } + visited_variants.insert(*variant_index); + worklist.push(targets.target_for_value(variant_index.as_u32().into())); + } + // The derefer may have inserted a local to access the variant. + // Make sure we keep track of it here. 
+ StatementKind::Assign(box (place, Rvalue::CopyForDeref(deref_place))) => { + if !is_pinned { + continue; + } + let PlaceRef { + local: SELF_LOCAL, + projection: &[PlaceElem::Field(FIELD_ZERO, _)], + } = deref_place.as_ref() + else { + continue; + }; + discr_locals.insert(place.as_local().unwrap()); + } + _ => {} + } + } + + // Also walk all the successors of this block. + if let Some(term) = &data.terminator { + worklist.extend(term.successors()); + } + } + + // Filter out the variants that are unreachable. + let TerminatorKind::SwitchInt { targets, .. } = + &mut body.basic_blocks.as_mut()[START_BLOCK].terminator_mut().kind + else { + unreachable!(); + }; + *targets = SwitchTargets::new( + targets + .iter() + .filter(|(idx, _)| visited_variants.contains(VariantIdx::from_u32(*idx as u32))), + targets.otherwise(), + ); + + // FIXME: We could remove dead variant fields from the coroutine layout, too. +} + pub(super) fn remove_dead_blocks(body: &mut Body<'_>) { let should_deduplicate_unreachable = |bbdata: &BasicBlockData<'_>| { // CfgSimplifier::simplify leaves behind some unreachable basic blocks without a diff --git a/tests/mir-opt/coroutine_dead_variants.outer-{closure#0}.SimplifyCfg-final.panic-abort.diff b/tests/mir-opt/coroutine_dead_variants.outer-{closure#0}.SimplifyCfg-final.panic-abort.diff new file mode 100644 index 0000000000000..ee258239d7df8 --- /dev/null +++ b/tests/mir-opt/coroutine_dead_variants.outer-{closure#0}.SimplifyCfg-final.panic-abort.diff @@ -0,0 +1,158 @@ +- // MIR for `outer::{closure#0}` before SimplifyCfg-final ++ // MIR for `outer::{closure#0}` after SimplifyCfg-final + /* coroutine_layout = CoroutineLayout { + field_tys: { + _0: CoroutineSavedTy { + ty: Coroutine( + DefId(0:5 ~ coroutine_dead_variants[61c4]::inner::{closure#0}), + [ + (), + std::future::ResumeTy, + (), + (), + CoroutineWitness( + DefId(0:5 ~ coroutine_dead_variants[61c4]::inner::{closure#0}), + [], + ), + (), + ], + ), + source_info: SourceInfo { + span: 
$DIR/coroutine_dead_variants.rs:13:9: 13:22 (#16), + scope: scope[0], + }, + ignore_for_traits: false, + }, + }, + variant_fields: { + Unresumed(0): [], + Returned (1): [], + Panicked (2): [], + Suspend0 (3): [_0], + }, + storage_conflicts: BitMatrix(1x1) { + (_0, _0), + }, + } */ + + fn outer::{closure#0}(_1: Pin<&mut {async fn body of outer()}>, _2: &mut Context<'_>) -> Poll<()> { + debug _task_context => _2; + let mut _0: std::task::Poll<()>; + let mut _3: {async fn body of inner()}; + let mut _4: {async fn body of inner()}; + let mut _5: std::task::Poll<()>; + let mut _6: std::pin::Pin<&mut {async fn body of inner()}>; + let mut _7: &mut {async fn body of inner()}; + let mut _8: &mut std::task::Context<'_>; + let mut _9: isize; + let mut _11: (); + let mut _12: &mut std::task::Context<'_>; + let mut _13: u32; + let mut _14: &mut {async fn body of outer()}; + scope 1 { + debug __awaitee => (((*(_1.0: &mut {async fn body of outer()})) as variant#3).0: {async fn body of inner()}); + let _10: (); + scope 2 { + debug result => const (); + } + } + + bb0: { + _14 = copy (_1.0: &mut {async fn body of outer()}); + _13 = discriminant((*_14)); +- switchInt(move _13) -> [0: bb1, 1: bb15, 3: bb14, otherwise: bb8]; ++ switchInt(move _13) -> [0: bb2, 1: bb4, otherwise: bb1]; + } + + bb1: { +- nop; +- goto -> bb12; +- } +- +- bb2: { +- StorageLive(_3); +- StorageLive(_4); +- _4 = inner() -> [return: bb3, unwind unreachable]; +- } +- +- bb3: { +- _3 = <{async fn body of inner()} as IntoFuture>::into_future(move _4) -> [return: bb4, unwind unreachable]; +- } +- +- bb4: { +- StorageDead(_4); +- (((*_14) as variant#3).0: {async fn body of inner()}) = move _3; +- goto -> bb5; +- } +- +- bb5: { +- StorageLive(_5); +- StorageLive(_6); +- _7 = &mut (((*_14) as variant#3).0: {async fn body of inner()}); +- _6 = Pin::<&mut {async fn body of inner()}>::new_unchecked(copy _7) -> [return: bb6, unwind unreachable]; +- } +- +- bb6: { +- nop; +- _5 = <{async fn body of inner()} as 
Future>::poll(move _6, copy _2) -> [return: bb7, unwind unreachable]; +- } +- +- bb7: { +- StorageDead(_6); +- _9 = discriminant(_5); +- switchInt(move _9) -> [0: bb10, 1: bb9, otherwise: bb8]; +- } +- +- bb8: { + unreachable; + } + +- bb9: { +- StorageDead(_5); +- _0 = const Poll::<()>::Pending; +- StorageDead(_3); +- discriminant((*_14)) = 3; +- return; +- } +- +- bb10: { +- StorageLive(_10); +- nop; +- StorageDead(_10); +- StorageDead(_5); +- drop((((*_14) as variant#3).0: {async fn body of inner()})) -> [return: bb11, unwind unreachable]; +- } +- +- bb11: { +- StorageDead(_3); ++ bb2: { + _11 = const (); +- goto -> bb13; ++ goto -> bb3; + } + +- bb12: { +- _11 = const (); +- goto -> bb13; +- } +- +- bb13: { ++ bb3: { + _0 = Poll::<()>::Ready(const ()); + discriminant((*_14)) = 1; + return; + } + +- bb14: { +- StorageLive(_3); +- nop; +- goto -> bb5; +- } +- +- bb15: { +- assert(const false, "`async fn` resumed after completion") -> [success: bb15, unwind unreachable]; ++ bb4: { ++ assert(const false, "`async fn` resumed after completion") -> [success: bb4, unwind unreachable]; + } + } + diff --git a/tests/mir-opt/coroutine_dead_variants.outer-{closure#0}.SimplifyCfg-final.panic-unwind.diff b/tests/mir-opt/coroutine_dead_variants.outer-{closure#0}.SimplifyCfg-final.panic-unwind.diff new file mode 100644 index 0000000000000..fcb2b3be7f844 --- /dev/null +++ b/tests/mir-opt/coroutine_dead_variants.outer-{closure#0}.SimplifyCfg-final.panic-unwind.diff @@ -0,0 +1,194 @@ +- // MIR for `outer::{closure#0}` before SimplifyCfg-final ++ // MIR for `outer::{closure#0}` after SimplifyCfg-final + /* coroutine_layout = CoroutineLayout { + field_tys: { + _0: CoroutineSavedTy { + ty: Coroutine( + DefId(0:5 ~ coroutine_dead_variants[61c4]::inner::{closure#0}), + [ + (), + std::future::ResumeTy, + (), + (), + CoroutineWitness( + DefId(0:5 ~ coroutine_dead_variants[61c4]::inner::{closure#0}), + [], + ), + (), + ], + ), + source_info: SourceInfo { + span: 
$DIR/coroutine_dead_variants.rs:13:9: 13:22 (#16), + scope: scope[0], + }, + ignore_for_traits: false, + }, + }, + variant_fields: { + Unresumed(0): [], + Returned (1): [], + Panicked (2): [], + Suspend0 (3): [_0], + }, + storage_conflicts: BitMatrix(1x1) { + (_0, _0), + }, + } */ + + fn outer::{closure#0}(_1: Pin<&mut {async fn body of outer()}>, _2: &mut Context<'_>) -> Poll<()> { + debug _task_context => _2; + let mut _0: std::task::Poll<()>; + let mut _3: {async fn body of inner()}; + let mut _4: {async fn body of inner()}; + let mut _5: std::task::Poll<()>; + let mut _6: std::pin::Pin<&mut {async fn body of inner()}>; + let mut _7: &mut {async fn body of inner()}; + let mut _8: &mut std::task::Context<'_>; + let mut _9: isize; + let mut _11: (); + let mut _12: &mut std::task::Context<'_>; + let mut _13: u32; + let mut _14: &mut {async fn body of outer()}; + scope 1 { + debug __awaitee => (((*(_1.0: &mut {async fn body of outer()})) as variant#3).0: {async fn body of inner()}); + let _10: (); + scope 2 { + debug result => const (); + } + } + + bb0: { + _14 = copy (_1.0: &mut {async fn body of outer()}); + _13 = discriminant((*_14)); +- switchInt(move _13) -> [0: bb1, 1: bb21, 2: bb20, 3: bb19, otherwise: bb8]; ++ switchInt(move _13) -> [0: bb2, 1: bb5, 2: bb4, otherwise: bb1]; + } + + bb1: { +- nop; +- goto -> bb12; +- } +- +- bb2: { +- StorageLive(_3); +- StorageLive(_4); +- _4 = inner() -> [return: bb3, unwind: bb17]; +- } +- +- bb3: { +- _3 = <{async fn body of inner()} as IntoFuture>::into_future(move _4) -> [return: bb4, unwind: bb17]; +- } +- +- bb4: { +- StorageDead(_4); +- (((*_14) as variant#3).0: {async fn body of inner()}) = move _3; +- goto -> bb5; +- } +- +- bb5: { +- StorageLive(_5); +- StorageLive(_6); +- _7 = &mut (((*_14) as variant#3).0: {async fn body of inner()}); +- _6 = Pin::<&mut {async fn body of inner()}>::new_unchecked(copy _7) -> [return: bb6, unwind: bb15]; +- } +- +- bb6: { +- nop; +- _5 = <{async fn body of inner()} as 
Future>::poll(move _6, copy _2) -> [return: bb7, unwind: bb14]; +- } +- +- bb7: { +- StorageDead(_6); +- _9 = discriminant(_5); +- switchInt(move _9) -> [0: bb10, 1: bb9, otherwise: bb8]; +- } +- +- bb8: { + unreachable; + } + +- bb9: { +- StorageDead(_5); +- _0 = const Poll::<()>::Pending; +- StorageDead(_3); +- discriminant((*_14)) = 3; +- return; +- } +- +- bb10: { +- StorageLive(_10); +- nop; +- StorageDead(_10); +- StorageDead(_5); +- drop((((*_14) as variant#3).0: {async fn body of inner()})) -> [return: bb11, unwind: bb18]; +- } +- +- bb11: { +- StorageDead(_3); ++ bb2: { + _11 = const (); +- goto -> bb13; ++ goto -> bb3; + } + +- bb12: { +- _11 = const (); +- goto -> bb13; +- } +- +- bb13: { ++ bb3: { + _0 = Poll::<()>::Ready(const ()); + discriminant((*_14)) = 1; + return; + } + +- bb14 (cleanup): { +- StorageDead(_6); +- goto -> bb16; ++ bb4: { ++ assert(const false, "`async fn` resumed after panicking") -> [success: bb4, unwind continue]; + } + +- bb15 (cleanup): { +- StorageDead(_6); +- goto -> bb16; +- } +- +- bb16 (cleanup): { +- StorageDead(_5); +- drop((((*_14) as variant#3).0: {async fn body of inner()})) -> [return: bb18, unwind terminate(cleanup)]; +- } +- +- bb17 (cleanup): { +- StorageDead(_4); +- goto -> bb18; +- } +- +- bb18 (cleanup): { +- StorageDead(_3); +- discriminant((*_14)) = 2; +- resume; +- } +- +- bb19: { +- StorageLive(_3); +- nop; +- goto -> bb5; +- } +- +- bb20: { +- assert(const false, "`async fn` resumed after panicking") -> [success: bb20, unwind continue]; +- } +- +- bb21: { +- assert(const false, "`async fn` resumed after completion") -> [success: bb21, unwind continue]; +- } +- +- bb22 (cleanup): { +- resume; ++ bb5: { ++ assert(const false, "`async fn` resumed after completion") -> [success: bb5, unwind continue]; + } + } + diff --git a/tests/mir-opt/coroutine_dead_variants.rs b/tests/mir-opt/coroutine_dead_variants.rs new file mode 100644 index 0000000000000..e0960cf4884e7 --- /dev/null +++ 
b/tests/mir-opt/coroutine_dead_variants.rs @@ -0,0 +1,15 @@ +// skip-filecheck +// EMIT_MIR_FOR_EACH_PANIC_STRATEGY +//@ compile-flags: -Zmir-enable-passes=+GVN,+SimplifyLocals-after-value-numbering +//@ edition: 2021 + +async fn inner() { + panic!("disco"); +} + +// EMIT_MIR coroutine_dead_variants.outer-{closure#0}.SimplifyCfg-final.diff +async fn outer() { + if false { + inner().await; + } +}