Skip to content

Commit 8893940

Browse files
committed
!! (WIP) make Thir responsible for walking THIR patterns
1 parent 2d2a9ab commit 8893940

File tree

5 files changed

+33
-25
lines changed

5 files changed

+33
-25
lines changed

compiler/rustc_middle/src/thir.rs

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -640,11 +640,17 @@ impl<'tcx> Pat<'tcx> {
640640
_ => None,
641641
}
642642
}
643+
}
643644

645+
impl<'tcx> Thir<'tcx> {
644646
/// Call `f` on every "binding" in a pattern, e.g., on `a` in
645647
/// `match foo() { Some(a) => (), None => () }`
646-
pub fn each_binding(&self, mut f: impl FnMut(Symbol, ByRef, Ty<'tcx>, Span)) {
647-
self.walk_always(|p| {
648+
pub fn for_each_binding_in_pat(
649+
&self,
650+
pat: &Pat<'tcx>,
651+
mut f: impl FnMut(Symbol, ByRef, Ty<'tcx>, Span),
652+
) {
653+
self.walk_pat_always(pat, |p| {
648654
if let PatKind::Binding { name, mode, ty, .. } = p.kind {
649655
f(name, mode.0, ty, p.span);
650656
}
@@ -654,22 +660,22 @@ impl<'tcx> Pat<'tcx> {
654660
/// Walk the pattern in left-to-right order.
655661
///
656662
/// If `it(pat)` returns `false`, the children are not visited.
657-
pub fn walk(&self, mut it: impl FnMut(&Pat<'tcx>) -> bool) {
658-
self.walk_(&mut it)
663+
pub fn walk_pat(&self, pat: &Pat<'tcx>, mut it: impl FnMut(&Pat<'tcx>) -> bool) {
664+
self.walk_pat_inner(pat, &mut it)
659665
}
660666

661-
fn walk_(&self, it: &mut impl FnMut(&Pat<'tcx>) -> bool) {
662-
if !it(self) {
667+
fn walk_pat_inner(&self, pat: &Pat<'tcx>, it: &mut impl FnMut(&Pat<'tcx>) -> bool) {
668+
if !it(pat) {
663669
return;
664670
}
665671

666-
for_each_immediate_subpat(self, |p| p.walk_(it));
672+
for_each_immediate_subpat(self, pat, |p| self.walk_pat_inner(p, it));
667673
}
668674

669675
/// Whether the pattern has a `PatKind::Error` nested within.
670-
pub fn pat_error_reported(&self) -> Result<(), ErrorGuaranteed> {
676+
pub fn pat_error_reported(&self, pat: &Pat<'tcx>) -> Result<(), ErrorGuaranteed> {
671677
let mut error = None;
672-
self.walk(|pat| {
678+
self.walk_pat(pat, |pat| {
673679
if let PatKind::Error(e) = pat.kind
674680
&& error.is_none()
675681
{
@@ -686,23 +692,23 @@ impl<'tcx> Pat<'tcx> {
686692
/// Walk the pattern in left-to-right order.
687693
///
688694
/// If you always want to recurse, prefer this method over `walk`.
689-
pub fn walk_always(&self, mut it: impl FnMut(&Pat<'tcx>)) {
690-
self.walk(|p| {
695+
pub fn walk_pat_always(&self, pat: &Pat<'tcx>, mut it: impl FnMut(&Pat<'tcx>)) {
696+
self.walk_pat(pat, |p| {
691697
it(p);
692698
true
693699
})
694700
}
695701

696702
/// Whether this a never pattern.
697-
pub fn is_never_pattern(&self) -> bool {
703+
pub fn is_never_pattern(&self, pat: &Pat<'tcx>) -> bool {
698704
let mut is_never_pattern = false;
699-
self.walk(|pat| match &pat.kind {
705+
self.walk_pat(pat, |pat| match &pat.kind {
700706
PatKind::Never => {
701707
is_never_pattern = true;
702708
false
703709
}
704710
PatKind::Or { pats } => {
705-
is_never_pattern = pats.iter().all(|p| p.is_never_pattern());
711+
is_never_pattern = pats.iter().all(|p| self.is_never_pattern(p));
706712
false
707713
}
708714
_ => true,

compiler/rustc_middle/src/thir/visit.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,11 +237,11 @@ pub fn walk_pat<'thir, 'tcx: 'thir, V: Visitor<'thir, 'tcx>>(
237237
visitor: &mut V,
238238
pat: &'thir Pat<'tcx>,
239239
) {
240-
for_each_immediate_subpat(pat, |p| visitor.visit_pat(p));
240+
for_each_immediate_subpat(visitor.thir(), pat, |p| visitor.visit_pat(p));
241241
}
242242

243-
/// Invokes `callback` on each immediate subpattern of `pat`, if any.
244243
pub(crate) fn for_each_immediate_subpat<'a, 'tcx>(
244+
_thir: &'a Thir<'tcx>,
245245
pat: &'a Pat<'tcx>,
246246
mut callback: impl FnMut(&'a Pat<'tcx>),
247247
) {

compiler/rustc_mir_build/src/builder/matches/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -861,7 +861,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
861861
pattern: &Pat<'tcx>,
862862
f: &mut impl FnMut(&mut Self, LocalVarId, Span),
863863
) {
864-
pattern.walk_always(|pat| {
864+
self.thir.walk_pat_always(pattern, |pat| {
865865
if let PatKind::Binding { var, is_primary: true, .. } = pat.kind {
866866
f(self, var, pat.span);
867867
}
@@ -1037,7 +1037,7 @@ impl<'tcx> FlatPat<'tcx> {
10371037
span: pattern.span,
10381038
bindings: Vec::new(),
10391039
ascriptions: Vec::new(),
1040-
is_never: pattern.is_never_pattern(),
1040+
is_never: cx.thir.is_never_pattern(pattern),
10411041
};
10421042
// Recursively remove irrefutable match pairs, while recording their
10431043
// bindings/ascriptions, and sort or-patterns after other match pairs.

compiler/rustc_mir_build/src/thir/pattern/check_match.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -277,14 +277,14 @@ impl<'p, 'tcx> MatchVisitor<'p, 'tcx> {
277277
cx: &PatCtxt<'p, 'tcx>,
278278
pat: &'p Pat<'tcx>,
279279
) -> Result<&'p DeconstructedPat<'p, 'tcx>, ErrorGuaranteed> {
280-
if let Err(err) = pat.pat_error_reported() {
280+
if let Err(err) = cx.thir.pat_error_reported(pat) {
281281
self.error = Err(err);
282282
Err(err)
283283
} else {
284284
// Check the pattern for some things unrelated to exhaustiveness.
285285
let refutable = if cx.refutable { Refutable } else { Irrefutable };
286286
let mut err = Ok(());
287-
pat.walk_always(|pat| {
287+
cx.thir.walk_pat_always(pat, |pat| {
288288
check_borrow_conflicts_in_at_patterns(self, pat);
289289
check_for_bindings_named_same_as_variants(self, pat, refutable);
290290
err = err.and(check_never_pattern(cx, pat));
@@ -385,6 +385,7 @@ impl<'p, 'tcx> MatchVisitor<'p, 'tcx> {
385385
scrutinee.map(|scrut| self.is_known_valid_scrutinee(scrut)).unwrap_or(true);
386386
PatCtxt {
387387
tcx: self.tcx,
388+
thir: self.thir,
388389
typeck_results: self.typeck_results,
389390
typing_env: self.typing_env,
390391
module: self.tcx.parent_module(self.lint_level).to_def_id(),
@@ -704,7 +705,7 @@ impl<'p, 'tcx> MatchVisitor<'p, 'tcx> {
704705
&& scrut.is_some()
705706
{
706707
let mut bindings = vec![];
707-
pat.each_binding(|name, _, _, _| bindings.push(name));
708+
self.thir.for_each_binding_in_pat(pat, |name, _, _, _| bindings.push(name));
708709

709710
let semi_span = span.shrink_to_hi();
710711
let start_span = span.shrink_to_lo();
@@ -780,7 +781,7 @@ fn check_borrow_conflicts_in_at_patterns<'tcx>(cx: &MatchVisitor<'_, 'tcx>, pat:
780781
ByRef::No if is_binding_by_move(ty) => {
781782
// We have `x @ pat` where `x` is by-move. Reject all borrows in `pat`.
782783
let mut conflicts_ref = Vec::new();
783-
sub.each_binding(|_, mode, _, span| {
784+
cx.thir.for_each_binding_in_pat(sub, |_, mode, _, span| {
784785
if matches!(mode, ByRef::Yes(_)) {
785786
conflicts_ref.push(span)
786787
}
@@ -809,7 +810,7 @@ fn check_borrow_conflicts_in_at_patterns<'tcx>(cx: &MatchVisitor<'_, 'tcx>, pat:
809810
let mut conflicts_move = Vec::new();
810811
let mut conflicts_mut_mut = Vec::new();
811812
let mut conflicts_mut_ref = Vec::new();
812-
sub.each_binding(|name, mode, ty, span| {
813+
cx.thir.for_each_binding_in_pat(sub, |name, mode, ty, span| {
813814
match mode {
814815
ByRef::Yes(mut_inner) => match (mut_outer, mut_inner) {
815816
// Both sides are `ref`.

compiler/rustc_pattern_analysis/src/rustc.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use rustc_hir::def_id::DefId;
88
use rustc_index::{Idx, IndexVec};
99
use rustc_middle::middle::stability::EvalResult;
1010
use rustc_middle::mir::{self, Const};
11-
use rustc_middle::thir::{self, Pat, PatKind, PatRange, PatRangeBoundary};
11+
use rustc_middle::thir::{self, Pat, PatKind, PatRange, PatRangeBoundary, Thir};
1212
use rustc_middle::ty::layout::IntegerExt;
1313
use rustc_middle::ty::{
1414
self, FieldDef, OpaqueTypeKey, ScalarInt, Ty, TyCtxt, TypeVisitableExt, VariantDef,
@@ -76,8 +76,9 @@ impl<'tcx> RevealedTy<'tcx> {
7676
}
7777

7878
#[derive(Clone)]
79-
pub struct RustcPatCtxt<'p, 'tcx: 'p> {
79+
pub struct RustcPatCtxt<'p, 'tcx> {
8080
pub tcx: TyCtxt<'tcx>,
81+
pub thir: &'p Thir<'tcx>,
8182
pub typeck_results: &'tcx ty::TypeckResults<'tcx>,
8283
/// The module in which the match occurs. This is necessary for
8384
/// checking inhabited-ness of types because whether a type is (visibly)

0 commit comments

Comments
 (0)