Skip to content

Commit 66697d5

Browse files
committed
Auto merge of #141581 - lcnr:fold-clauses, r=<try>
add additional `TypeFlags` fast paths Some crates, e.g. `diesel`, have items with a lot of where-clauses (more than 150). In these cases checking the `TypeFlags` of the whole `param_env` can be very beneficial. This adds `fn fold_clauses` to mirror the existing `fn visit_clauses` and then uses this in folders which fold `ParamEnv`s. Split out from #141451, depends on #141500. r? `@compiler-errors`
2 parents 95a2212 + 1a2a21a commit 66697d5

File tree

16 files changed

+197
-26
lines changed

16 files changed

+197
-26
lines changed

compiler/rustc_infer/src/infer/canonical/canonicalizer.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,14 @@ impl<'cx, 'tcx> TypeFolder<TyCtxt<'tcx>> for Canonicalizer<'cx, 'tcx> {
493493
ct
494494
}
495495
}
496+
497+
fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
498+
if p.flags().intersects(self.needs_canonical_flags) { p.super_fold_with(self) } else { p }
499+
}
500+
501+
fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
502+
if c.flags().intersects(self.needs_canonical_flags) { c.super_fold_with(self) } else { c }
503+
}
496504
}
497505

498506
impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> {

compiler/rustc_infer/src/infer/resolve.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for OpportunisticVarResolver<'a, 'tcx> {
5555
ct.super_fold_with(self)
5656
}
5757
}
58+
59+
fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
60+
if !p.has_non_region_infer() { p } else { p.super_fold_with(self) }
61+
}
62+
63+
fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
64+
if !c.has_non_region_infer() { c } else { c.super_fold_with(self) }
65+
}
5866
}
5967

6068
/// The opportunistic region resolver opportunistically resolves regions

compiler/rustc_middle/src/ty/erase_regions.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,12 @@ impl<'tcx> TypeFolder<TyCtxt<'tcx>> for RegionEraserVisitor<'tcx> {
8686
p
8787
}
8888
}
89+
90+
fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
91+
if c.has_type_flags(TypeFlags::HAS_BINDER_VARS | TypeFlags::HAS_FREE_REGIONS) {
92+
c.super_fold_with(self)
93+
} else {
94+
c
95+
}
96+
}
8997
}

compiler/rustc_middle/src/ty/fold.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,10 @@ where
177177
fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
178178
if p.has_vars_bound_at_or_above(self.current_index) { p.super_fold_with(self) } else { p }
179179
}
180+
181+
fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
182+
if c.has_vars_bound_at_or_above(self.current_index) { c.super_fold_with(self) } else { c }
183+
}
180184
}
181185

182186
impl<'tcx> TyCtxt<'tcx> {

compiler/rustc_middle/src/ty/predicate.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,8 @@ impl<'tcx> Clause<'tcx> {
238238
}
239239
}
240240

241+
impl<'tcx> rustc_type_ir::inherent::Clauses<TyCtxt<'tcx>> for ty::Clauses<'tcx> {}
242+
241243
#[extension(pub trait ExistentialPredicateStableCmpExt<'tcx>)]
242244
impl<'tcx> ExistentialPredicate<'tcx> {
243245
/// Compares via an ordering that will not change if modules are reordered or other changes are

compiler/rustc_middle/src/ty/structural_impls.rs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,19 @@ impl<'tcx> TypeFoldable<TyCtxt<'tcx>> for ty::Clause<'tcx> {
570570
}
571571
}
572572

573+
impl<'tcx> TypeFoldable<TyCtxt<'tcx>> for ty::Clauses<'tcx> {
574+
fn try_fold_with<F: FallibleTypeFolder<TyCtxt<'tcx>>>(
575+
self,
576+
folder: &mut F,
577+
) -> Result<Self, F::Error> {
578+
folder.try_fold_clauses(self)
579+
}
580+
581+
fn fold_with<F: TypeFolder<TyCtxt<'tcx>>>(self, folder: &mut F) -> Self {
582+
folder.fold_clauses(self)
583+
}
584+
}
585+
573586
impl<'tcx> TypeVisitable<TyCtxt<'tcx>> for ty::Predicate<'tcx> {
574587
fn visit_with<V: TypeVisitor<TyCtxt<'tcx>>>(&self, visitor: &mut V) -> V::Result {
575588
visitor.visit_predicate(*self)
@@ -615,6 +628,19 @@ impl<'tcx> TypeSuperVisitable<TyCtxt<'tcx>> for ty::Clauses<'tcx> {
615628
}
616629
}
617630

631+
impl<'tcx> TypeSuperFoldable<TyCtxt<'tcx>> for ty::Clauses<'tcx> {
632+
fn try_super_fold_with<F: FallibleTypeFolder<TyCtxt<'tcx>>>(
633+
self,
634+
folder: &mut F,
635+
) -> Result<Self, F::Error> {
636+
ty::util::try_fold_list(self, folder, |tcx, v| tcx.mk_clauses(v))
637+
}
638+
639+
fn super_fold_with<F: TypeFolder<TyCtxt<'tcx>>>(self, folder: &mut F) -> Self {
640+
ty::util::fold_list(self, folder, |tcx, v| tcx.mk_clauses(v))
641+
}
642+
}
643+
618644
impl<'tcx> TypeFoldable<TyCtxt<'tcx>> for ty::Const<'tcx> {
619645
fn try_fold_with<F: FallibleTypeFolder<TyCtxt<'tcx>>>(
620646
self,
@@ -775,7 +801,6 @@ macro_rules! list_fold {
775801
}
776802

777803
list_fold! {
778-
ty::Clauses<'tcx> : mk_clauses,
779804
&'tcx ty::List<ty::PolyExistentialPredicate<'tcx>> : mk_poly_existential_predicates,
780805
&'tcx ty::List<PlaceElem<'tcx>> : mk_place_elems,
781806
&'tcx ty::List<ty::Pattern<'tcx>> : mk_patterns,

compiler/rustc_next_trait_solver/src/canonicalizer.rs

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,22 @@ use rustc_type_ir::data_structures::{HashMap, ensure_sufficient_stack};
44
use rustc_type_ir::inherent::*;
55
use rustc_type_ir::solve::{Goal, QueryInput};
66
use rustc_type_ir::{
7-
self as ty, Canonical, CanonicalTyVarKind, CanonicalVarKind, InferCtxtLike, Interner,
8-
TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitableExt,
7+
self as ty, Canonical, CanonicalTyVarKind, CanonicalVarKind, Flags, InferCtxtLike, Interner,
8+
TypeFlags, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitableExt,
99
};
1010

1111
use crate::delegate::SolverDelegate;
1212

13+
/// Does this have infer/placeholder/param, free regions or ReErased?
14+
const NEEDS_CANONICAL: TypeFlags = TypeFlags::from_bits(
15+
TypeFlags::HAS_INFER.bits()
16+
| TypeFlags::HAS_PLACEHOLDER.bits()
17+
| TypeFlags::HAS_PARAM.bits()
18+
| TypeFlags::HAS_FREE_REGIONS.bits()
19+
| TypeFlags::HAS_RE_ERASED.bits(),
20+
)
21+
.unwrap();
22+
1323
/// Whether we're canonicalizing a query input or the query response.
1424
///
1525
/// When canonicalizing an input we're in the context of the caller
@@ -79,7 +89,11 @@ impl<'a, D: SolverDelegate<Interner = I>, I: Interner> Canonicalizer<'a, D, I> {
7989
cache: Default::default(),
8090
};
8191

82-
let value = value.fold_with(&mut canonicalizer);
92+
let value = if value.has_type_flags(NEEDS_CANONICAL) {
93+
value.fold_with(&mut canonicalizer)
94+
} else {
95+
value
96+
};
8397
assert!(!value.has_infer(), "unexpected infer in {value:?}");
8498
assert!(!value.has_placeholders(), "unexpected placeholders in {value:?}");
8599
let (max_universe, variables) = canonicalizer.finalize();
@@ -111,7 +125,14 @@ impl<'a, D: SolverDelegate<Interner = I>, I: Interner> Canonicalizer<'a, D, I> {
111125

112126
cache: Default::default(),
113127
};
114-
let param_env = input.goal.param_env.fold_with(&mut env_canonicalizer);
128+
129+
let param_env = input.goal.param_env;
130+
let param_env = if param_env.has_type_flags(NEEDS_CANONICAL) {
131+
param_env.fold_with(&mut env_canonicalizer)
132+
} else {
133+
param_env
134+
};
135+
115136
debug_assert_eq!(env_canonicalizer.binder_index, ty::INNERMOST);
116137
// Then canonicalize the rest of the input without keeping `'static`
117138
// while *mostly* reusing the canonicalizer from above.
@@ -134,10 +155,22 @@ impl<'a, D: SolverDelegate<Interner = I>, I: Interner> Canonicalizer<'a, D, I> {
134155
cache: Default::default(),
135156
};
136157

137-
let predicate = input.goal.predicate.fold_with(&mut rest_canonicalizer);
158+
let predicate = input.goal.predicate;
159+
let predicate = if predicate.has_type_flags(NEEDS_CANONICAL) {
160+
predicate.fold_with(&mut rest_canonicalizer)
161+
} else {
162+
predicate
163+
};
138164
let goal = Goal { param_env, predicate };
165+
166+
let predefined_opaques_in_body = input.predefined_opaques_in_body;
139167
let predefined_opaques_in_body =
140-
input.predefined_opaques_in_body.fold_with(&mut rest_canonicalizer);
168+
if input.predefined_opaques_in_body.has_type_flags(NEEDS_CANONICAL) {
169+
predefined_opaques_in_body.fold_with(&mut rest_canonicalizer)
170+
} else {
171+
predefined_opaques_in_body
172+
};
173+
141174
let value = QueryInput { goal, predefined_opaques_in_body };
142175

143176
assert!(!value.has_infer(), "unexpected infer in {value:?}");
@@ -387,7 +420,11 @@ impl<'a, D: SolverDelegate<Interner = I>, I: Interner> Canonicalizer<'a, D, I> {
387420
| ty::Alias(_, _)
388421
| ty::Bound(_, _)
389422
| ty::Error(_) => {
390-
return ensure_sufficient_stack(|| t.super_fold_with(self));
423+
return if t.has_type_flags(NEEDS_CANONICAL) {
424+
ensure_sufficient_stack(|| t.super_fold_with(self))
425+
} else {
426+
t
427+
};
391428
}
392429
};
393430

@@ -522,11 +559,28 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for Canonicaliz
522559
| ty::ConstKind::Unevaluated(_)
523560
| ty::ConstKind::Value(_)
524561
| ty::ConstKind::Error(_)
525-
| ty::ConstKind::Expr(_) => return c.super_fold_with(self),
562+
| ty::ConstKind::Expr(_) => {
563+
return if c.has_type_flags(NEEDS_CANONICAL) { c.super_fold_with(self) } else { c };
564+
}
526565
};
527566

528567
let var = self.get_or_insert_bound_var(c, kind);
529568

530569
Const::new_anon_bound(self.cx(), self.binder_index, var)
531570
}
571+
572+
fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
573+
if p.flags().intersects(NEEDS_CANONICAL) { p.super_fold_with(self) } else { p }
574+
}
575+
576+
fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses {
577+
match self.canonicalize_mode {
578+
CanonicalizeMode::Input { keep_static: true }
579+
| CanonicalizeMode::Response { max_input_universe: _ } => {}
580+
CanonicalizeMode::Input { keep_static: false } => {
581+
panic!("erasing 'static in env")
582+
}
583+
}
584+
if c.flags().intersects(NEEDS_CANONICAL) { c.super_fold_with(self) } else { c }
585+
}
532586
}

compiler/rustc_next_trait_solver/src/resolve.rs

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use crate::delegate::SolverDelegate;
1111
// EAGER RESOLUTION
1212

1313
/// Resolves ty, region, and const vars to their inferred values or their root vars.
14-
pub struct EagerResolver<'a, D, I = <D as SolverDelegate>::Interner>
14+
struct EagerResolver<'a, D, I = <D as SolverDelegate>::Interner>
1515
where
1616
D: SolverDelegate<Interner = I>,
1717
I: Interner,
@@ -22,8 +22,20 @@ where
2222
cache: DelayedMap<I::Ty, I::Ty>,
2323
}
2424

25+
pub fn eager_resolve_vars<D: SolverDelegate, T: TypeFoldable<D::Interner>>(
26+
delegate: &D,
27+
value: T,
28+
) -> T {
29+
if value.has_infer() {
30+
let mut folder = EagerResolver::new(delegate);
31+
value.fold_with(&mut folder)
32+
} else {
33+
value
34+
}
35+
}
36+
2537
impl<'a, D: SolverDelegate> EagerResolver<'a, D> {
26-
pub fn new(delegate: &'a D) -> Self {
38+
fn new(delegate: &'a D) -> Self {
2739
EagerResolver { delegate, cache: Default::default() }
2840
}
2941
}
@@ -86,4 +98,12 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for EagerResolv
8698
}
8799
}
88100
}
101+
102+
fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
103+
if p.has_infer() { p.super_fold_with(self) } else { p }
104+
}
105+
106+
fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses {
107+
if c.has_infer() { c.super_fold_with(self) } else { c }
108+
}
89109
}

compiler/rustc_next_trait_solver/src/solve/eval_ctxt/canonical.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use tracing::{debug, instrument, trace};
2222

2323
use crate::canonicalizer::Canonicalizer;
2424
use crate::delegate::SolverDelegate;
25-
use crate::resolve::EagerResolver;
25+
use crate::resolve::eager_resolve_vars;
2626
use crate::solve::eval_ctxt::CurrentGoalKind;
2727
use crate::solve::{
2828
CanonicalInput, CanonicalResponse, Certainty, EvalCtxt, ExternalConstraintsData, Goal,
@@ -61,8 +61,7 @@ where
6161
// so we only canonicalize the lookup table and ignore
6262
// duplicate entries.
6363
let opaque_types = self.delegate.clone_opaque_types_lookup_table();
64-
let (goal, opaque_types) =
65-
(goal, opaque_types).fold_with(&mut EagerResolver::new(self.delegate));
64+
let (goal, opaque_types) = eager_resolve_vars(self.delegate, (goal, opaque_types));
6665

6766
let mut orig_values = Default::default();
6867
let canonical = Canonicalizer::canonicalize_input(
@@ -157,8 +156,8 @@ where
157156

158157
let external_constraints =
159158
self.compute_external_query_constraints(certainty, normalization_nested_goals);
160-
let (var_values, mut external_constraints) = (self.var_values, external_constraints)
161-
.fold_with(&mut EagerResolver::new(self.delegate));
159+
let (var_values, mut external_constraints) =
160+
eager_resolve_vars(self.delegate, (self.var_values, external_constraints));
162161

163162
// Remove any trivial or duplicated region constraints once we've resolved regions
164163
let mut unique = HashSet::default();
@@ -469,7 +468,7 @@ where
469468
{
470469
let var_values = CanonicalVarValues { var_values: delegate.cx().mk_args(var_values) };
471470
let state = inspect::State { var_values, data };
472-
let state = state.fold_with(&mut EagerResolver::new(delegate));
471+
let state = eager_resolve_vars(delegate, state);
473472
Canonicalizer::canonicalize_response(delegate, max_input_universe, &mut vec![], state)
474473
}
475474

compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,22 @@ where
848848
}
849849
}
850850
}
851+
852+
fn visit_predicate(&mut self, p: I::Predicate) -> Self::Result {
853+
if p.has_non_region_infer() || p.has_placeholders() {
854+
p.super_visit_with(self)
855+
} else {
856+
ControlFlow::Continue(())
857+
}
858+
}
859+
860+
fn visit_clauses(&mut self, c: I::Clauses) -> Self::Result {
861+
if c.has_non_region_infer() || c.has_placeholders() {
862+
c.super_visit_with(self)
863+
} else {
864+
ControlFlow::Continue(())
865+
}
866+
}
851867
}
852868

853869
let mut visitor = ContainsTermOrNotNameable {

compiler/rustc_trait_selection/src/solve/inspect/analyse.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ use rustc_infer::infer::{DefineOpaqueTypes, InferCtxt, InferOk};
1515
use rustc_macros::extension;
1616
use rustc_middle::traits::ObligationCause;
1717
use rustc_middle::traits::solve::{Certainty, Goal, GoalSource, NoSolution, QueryResult};
18-
use rustc_middle::ty::{TyCtxt, TypeFoldable, VisitorResult, try_visit};
18+
use rustc_middle::ty::{TyCtxt, VisitorResult, try_visit};
1919
use rustc_middle::{bug, ty};
20-
use rustc_next_trait_solver::resolve::EagerResolver;
20+
use rustc_next_trait_solver::resolve::eager_resolve_vars;
2121
use rustc_next_trait_solver::solve::inspect::{self, instantiate_canonical_state};
2222
use rustc_next_trait_solver::solve::{GenerateProofTree, MaybeCause, SolverDelegateEvalExt as _};
2323
use rustc_span::{DUMMY_SP, Span};
@@ -187,8 +187,7 @@ impl<'a, 'tcx> InspectCandidate<'a, 'tcx> {
187187
let _ = term_hack.constrain(infcx, span, param_env);
188188
}
189189

190-
let opt_impl_args =
191-
opt_impl_args.map(|impl_args| impl_args.fold_with(&mut EagerResolver::new(infcx)));
190+
let opt_impl_args = opt_impl_args.map(|impl_args| eager_resolve_vars(infcx, impl_args));
192191

193192
let goals = instantiated_goals
194193
.into_iter()
@@ -392,7 +391,7 @@ impl<'a, 'tcx> InspectGoal<'a, 'tcx> {
392391
infcx,
393392
depth,
394393
orig_values,
395-
goal: uncanonicalized_goal.fold_with(&mut EagerResolver::new(infcx)),
394+
goal: eager_resolve_vars(infcx, uncanonicalized_goal),
396395
result,
397396
evaluation_kind: evaluation.kind,
398397
normalizes_to_term_hack,

compiler/rustc_type_ir/src/binder.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,14 @@ impl<'a, I: Interner> TypeFolder<I> for ArgFolder<'a, I> {
711711
c.super_fold_with(self)
712712
}
713713
}
714+
715+
fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
716+
if p.has_param() { p.super_fold_with(self) } else { p }
717+
}
718+
719+
fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses {
720+
if c.has_param() { c.super_fold_with(self) } else { c }
721+
}
714722
}
715723

716724
impl<'a, I: Interner> ArgFolder<'a, I> {

0 commit comments

Comments
 (0)