From ab9d69fe615a736019b2e2b8f9ef4d6ac8a23509 Mon Sep 17 00:00:00 2001 From: lcnr Date: Sun, 27 Apr 2025 22:28:03 +0000 Subject: [PATCH 1/5] eagerly compute `sub_relations` again While still only using them for diagnostics. We could use them for cycle detection in generalization and it seems desirable to do so in the future. However, this is unsound with the old trait solver as its cache does not track these `sub_relations` in any way. We would also need to consider them when canonicalizing as otherwise instantiating the canonical response may fail. --- compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs | 7 -- compiler/rustc_infer/src/infer/context.rs | 4 + compiler/rustc_infer/src/infer/mod.rs | 5 ++ .../src/infer/relate/generalize.rs | 4 + .../src/infer/snapshot/undo_log.rs | 4 +- .../rustc_infer/src/infer/type_variable.rs | 79 +++++++++++++++++- .../src/solve/eval_ctxt/mod.rs | 4 + .../rustc_next_trait_solver/src/solve/mod.rs | 14 ++-- .../src/error_reporting/infer/mod.rs | 1 - .../error_reporting/infer/need_type_info.rs | 2 +- .../error_reporting/infer/sub_relations.rs | 81 ------------------- .../src/error_reporting/mod.rs | 4 - .../src/error_reporting/traits/mod.rs | 4 - compiler/rustc_type_ir/src/infer_ctxt.rs | 1 + 14 files changed, 107 insertions(+), 107 deletions(-) delete mode 100644 compiler/rustc_trait_selection/src/error_reporting/infer/sub_relations.rs diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs index de189b301092..54089b95bb4f 100644 --- a/compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs +++ b/compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs @@ -19,7 +19,6 @@ use rustc_middle::ty::{self, Const, Ty, TyCtxt, TypeVisitableExt}; use rustc_session::Session; use rustc_span::{self, DUMMY_SP, ErrorGuaranteed, Ident, Span, sym}; use rustc_trait_selection::error_reporting::TypeErrCtxt; -use rustc_trait_selection::error_reporting::infer::sub_relations::SubRelations; use rustc_trait_selection::traits::{ObligationCause, ObligationCauseCode, ObligationCtxt}; use crate::coercion::DynamicCoerceMany; @@ -177,14 +176,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { /// /// [`InferCtxtErrorExt::err_ctxt`]: rustc_trait_selection::error_reporting::InferCtxtErrorExt::err_ctxt pub(crate) fn err_ctxt(&'a self) -> TypeErrCtxt<'a, 'tcx> { - let mut sub_relations = SubRelations::default(); - sub_relations.add_constraints( - self, - self.fulfillment_cx.borrow_mut().pending_obligations().iter().map(|o| o.predicate), - ); TypeErrCtxt { infcx: &self.infcx, - sub_relations: RefCell::new(sub_relations), typeck_results: Some(self.typeck_results.borrow()), fallback_has_occurred: self.fallback_has_occurred.get(), normalize_fn_sig: Box::new(|fn_sig| { diff --git a/compiler/rustc_infer/src/infer/context.rs b/compiler/rustc_infer/src/infer/context.rs index 22d7ce79bb46..5b70eaf9cca8 100644 --- a/compiler/rustc_infer/src/infer/context.rs +++ b/compiler/rustc_infer/src/infer/context.rs @@ -125,6 +125,10 @@ impl<'tcx> rustc_type_ir::InferCtxtLike for InferCtxt<'tcx> { self.inner.borrow_mut().type_variables().equate(a, b); } + fn sub_ty_vids_raw(&self, a: ty::TyVid, b: ty::TyVid) { + self.inner.borrow_mut().type_variables().sub(a, b); + } + fn equate_int_vids_raw(&self, a: ty::IntVid, b: ty::IntVid) { self.inner.borrow_mut().int_unification_table().union(a, b); } diff --git a/compiler/rustc_infer/src/infer/mod.rs b/compiler/rustc_infer/src/infer/mod.rs index d25542dadd59..e78702b10ba3 100644 --- a/compiler/rustc_infer/src/infer/mod.rs +++ b/compiler/rustc_infer/src/infer/mod.rs @@ -733,6 +733,7 @@ impl<'tcx> InferCtxt<'tcx> { let r_b = self.shallow_resolve(predicate.skip_binder().b); match (r_a.kind(), r_b.kind()) { (&ty::Infer(ty::TyVar(a_vid)), &ty::Infer(ty::TyVar(b_vid))) => { + self.inner.borrow_mut().type_variables().sub(a_vid, b_vid); return Err((a_vid, b_vid)); } _ => {} @@ -1065,6 +1066,10 @@ impl<'tcx> InferCtxt<'tcx> { self.inner.borrow_mut().type_variables().root_var(var) } + pub fn sub_root_var(&self, var: ty::TyVid) -> ty::TyVid { + self.inner.borrow_mut().type_variables().sub_root_var(var) + } + pub fn root_const_var(&self, var: ty::ConstVid) -> ty::ConstVid { self.inner.borrow_mut().const_unification_table().find(var).vid } diff --git a/compiler/rustc_infer/src/infer/relate/generalize.rs b/compiler/rustc_infer/src/infer/relate/generalize.rs index 210b8f37d883..cdcd6f6dbb6f 100644 --- a/compiler/rustc_infer/src/infer/relate/generalize.rs +++ b/compiler/rustc_infer/src/infer/relate/generalize.rs @@ -503,6 +503,10 @@ impl<'tcx> TypeRelation> for Generalizer<'_, 'tcx> { let origin = inner.type_variables().var_origin(vid); let new_var_id = inner.type_variables().new_var(self.for_universe, origin); + // Record that `vid` and `new_var_id` have to be subtypes + // of each other. This is currently only used for diagnostics. + // To see why, see the docs in the `type_variables` module. + inner.type_variables().sub(vid, new_var_id); // If we're in the new solver and create a new inference // variable inside of an alias we eagerly constrain that // inference variable to prevent unexpected ambiguity errors. diff --git a/compiler/rustc_infer/src/infer/snapshot/undo_log.rs b/compiler/rustc_infer/src/infer/snapshot/undo_log.rs index ba7d8f588e68..cb2c9d8ce250 100644 --- a/compiler/rustc_infer/src/infer/snapshot/undo_log.rs +++ b/compiler/rustc_infer/src/infer/snapshot/undo_log.rs @@ -18,7 +18,7 @@ pub struct Snapshot<'tcx> { #[derive(Clone)] pub(crate) enum UndoLog<'tcx> { OpaqueTypes(OpaqueTypeKey<'tcx>, Option>), - TypeVariables(sv::UndoLog>>), + TypeVariables(type_variable::UndoLog<'tcx>), ConstUnificationTable(sv::UndoLog>>), IntUnificationTable(sv::UndoLog>), FloatUnificationTable(sv::UndoLog>), @@ -44,7 +44,9 @@ macro_rules! impl_from { impl_from! { RegionConstraintCollector(region_constraints::UndoLog<'tcx>), + TypeVariables(type_variable::UndoLog<'tcx>), TypeVariables(sv::UndoLog>>), + TypeVariables(sv::UndoLog>), IntUnificationTable(sv::UndoLog>), FloatUnificationTable(sv::UndoLog>), diff --git a/compiler/rustc_infer/src/infer/type_variable.rs b/compiler/rustc_infer/src/infer/type_variable.rs index 2086483b94a7..b343baff3603 100644 --- a/compiler/rustc_infer/src/infer/type_variable.rs +++ b/compiler/rustc_infer/src/infer/type_variable.rs @@ -13,9 +13,33 @@ use tracing::debug; use crate::infer::InferCtxtUndoLogs; -impl<'tcx> Rollback>>> for TypeVariableStorage<'tcx> { - fn reverse(&mut self, undo: sv::UndoLog>>) { - self.eq_relations.reverse(undo) +/// Represents a single undo-able action that affects a type inference variable. +#[derive(Clone)] +pub(crate) enum UndoLog<'tcx> { + EqRelation(sv::UndoLog>>), + SubRelation(sv::UndoLog>), +} + +/// Convert from a specific kind of undo to the more general UndoLog +impl<'tcx> From>>> for UndoLog<'tcx> { + fn from(l: sv::UndoLog>>) -> Self { + UndoLog::EqRelation(l) + } +} + +/// Convert from a specific kind of undo to the more general UndoLog +impl<'tcx> From>> for UndoLog<'tcx> { + fn from(l: sv::UndoLog>) -> Self { + UndoLog::SubRelation(l) + } +} + +impl<'tcx> Rollback> for TypeVariableStorage<'tcx> { + fn reverse(&mut self, undo: UndoLog<'tcx>) { + match undo { + UndoLog::EqRelation(undo) => self.eq_relations.reverse(undo), + UndoLog::SubRelation(undo) => self.sub_relations.reverse(undo), + } } } @@ -27,6 +51,24 @@ pub(crate) struct TypeVariableStorage<'tcx> { /// constraint `?X == ?Y`. This table also stores, for each key, /// the known value. eq_relations: ut::UnificationTableStorage>, + + /// Only used by `-Znext-solver` and for diagnostics. + /// + /// When reporting ambiguity errors, we sometimes want to + /// treat all inference vars which are subtypes of each + /// others as if they are equal. For this case we compute + /// the transitive closure of our subtype obligations here. + /// + /// E.g. when encountering ambiguity errors, we want to suggest + /// specifying some method argument or to add a type annotation + /// to a local variable. Because subtyping cannot change the + /// shape of a type, it's fine if the cause of the ambiguity error + /// is only related to the suggested variable via subtyping. + /// + /// Even for something like `let x = returns_arg(); x.method();` the + /// type of `x` is only a supertype of the argument of `returns_arg`. We + /// still want to suggest specifying the type of the argument. + sub_relations: ut::UnificationTableStorage, } pub(crate) struct TypeVariableTable<'a, 'tcx> { @@ -109,6 +151,16 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> { debug_assert!(self.probe(a).is_unknown()); debug_assert!(self.probe(b).is_unknown()); self.eq_relations().union(a, b); + self.sub_relations().union(a, b); + } + + /// Records that `a <: b`, depending on `dir`. + /// + /// Precondition: neither `a` nor `b` are known. + pub(crate) fn sub(&mut self, a: ty::TyVid, b: ty::TyVid) { + debug_assert!(self.probe(a).is_unknown()); + debug_assert!(self.probe(b).is_unknown()); + self.sub_relations().union(a, b); } /// Instantiates `vid` with the type `ty`. @@ -142,6 +194,10 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> { origin: TypeVariableOrigin, ) -> ty::TyVid { let eq_key = self.eq_relations().new_key(TypeVariableValue::Unknown { universe }); + + let sub_key = self.sub_relations().new_key(()); + debug_assert_eq!(eq_key.vid, sub_key); + let index = self.storage.values.push(TypeVariableData { origin }); debug_assert_eq!(eq_key.vid, index); @@ -164,6 +220,18 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> { self.eq_relations().find(vid).vid } + /// Returns the "root" variable of `vid` in the `sub_relations` + /// equivalence table. All type variables that have been are + /// related via equality or subtyping will yield the same root + /// variable (per the union-find algorithm), so `sub_root_var(a) + /// == sub_root_var(b)` implies that: + /// ```text + /// exists X. (a <: X || X <: a) && (b <: X || X <: b) + /// ``` + pub(crate) fn sub_root_var(&mut self, vid: ty::TyVid) -> ty::TyVid { + self.sub_relations().find(vid) + } + /// Retrieves the type to which `vid` has been instantiated, if /// any. pub(crate) fn probe(&mut self, vid: ty::TyVid) -> TypeVariableValue<'tcx> { @@ -181,6 +249,11 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> { self.storage.eq_relations.with_log(self.undo_log) } + #[inline] + fn sub_relations(&mut self) -> super::UnificationTable<'_, 'tcx, ty::TyVid> { + self.storage.sub_relations.with_log(self.undo_log) + } + /// Returns a range of the type variables created during the snapshot. pub(crate) fn vars_since_snapshot( &mut self, diff --git a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs index 6dd554299a69..b072decffe0e 100644 --- a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs +++ b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs @@ -844,6 +844,10 @@ where && goal.param_env.visit_with(&mut visitor).is_continue() } + pub(super) fn sub_ty_vids_raw(&self, a: ty::TyVid, b: ty::TyVid) { + self.delegate.sub_ty_vids_raw(a, b) + } + #[instrument(level = "trace", skip(self, param_env), ret)] pub(super) fn eq>( &mut self, diff --git a/compiler/rustc_next_trait_solver/src/solve/mod.rs b/compiler/rustc_next_trait_solver/src/solve/mod.rs index c9f4fc649b52..639d103f9020 100644 --- a/compiler/rustc_next_trait_solver/src/solve/mod.rs +++ b/compiler/rustc_next_trait_solver/src/solve/mod.rs @@ -120,11 +120,15 @@ where #[instrument(level = "trace", skip(self))] fn compute_subtype_goal(&mut self, goal: Goal>) -> QueryResult { - if goal.predicate.a.is_ty_var() && goal.predicate.b.is_ty_var() { - self.evaluate_added_goals_and_make_canonical_response(Certainty::AMBIGUOUS) - } else { - self.sub(goal.param_env, goal.predicate.a, goal.predicate.b)?; - self.evaluate_added_goals_and_make_canonical_response(Certainty::Yes) + match (goal.predicate.a.kind(), goal.predicate.b.kind()) { + (ty::Infer(ty::TyVar(a_vid)), ty::Infer(ty::TyVar(b_vid))) => { + self.sub_ty_vids_raw(a_vid, b_vid); + self.evaluate_added_goals_and_make_canonical_response(Certainty::AMBIGUOUS) + } + _ => { + self.sub(goal.param_env, goal.predicate.a, goal.predicate.b)?; + self.evaluate_added_goals_and_make_canonical_response(Certainty::Yes) + } } } diff --git a/compiler/rustc_trait_selection/src/error_reporting/infer/mod.rs b/compiler/rustc_trait_selection/src/error_reporting/infer/mod.rs index fdd547448f00..00230f45a892 100644 --- a/compiler/rustc_trait_selection/src/error_reporting/infer/mod.rs +++ b/compiler/rustc_trait_selection/src/error_reporting/infer/mod.rs @@ -92,7 +92,6 @@ mod suggest; pub mod need_type_info; pub mod nice_region_error; pub mod region; -pub mod sub_relations; /// Makes a valid string literal from a string by escaping special characters (" and \), /// unless they are already escaped. diff --git a/compiler/rustc_trait_selection/src/error_reporting/infer/need_type_info.rs b/compiler/rustc_trait_selection/src/error_reporting/infer/need_type_info.rs index de9a50f19623..72571a177e1e 100644 --- a/compiler/rustc_trait_selection/src/error_reporting/infer/need_type_info.rs +++ b/compiler/rustc_trait_selection/src/error_reporting/infer/need_type_info.rs @@ -945,7 +945,7 @@ impl<'a, 'tcx> FindInferSourceVisitor<'a, 'tcx> { use ty::{Infer, TyVar}; match (inner_ty.kind(), target_ty.kind()) { (&Infer(TyVar(a_vid)), &Infer(TyVar(b_vid))) => { - self.tecx.sub_relations.borrow_mut().unified(self.tecx, a_vid, b_vid) + self.tecx.sub_root_var(a_vid) == self.tecx.sub_root_var(b_vid) } _ => false, } diff --git a/compiler/rustc_trait_selection/src/error_reporting/infer/sub_relations.rs b/compiler/rustc_trait_selection/src/error_reporting/infer/sub_relations.rs deleted file mode 100644 index ef26a8ff7b86..000000000000 --- a/compiler/rustc_trait_selection/src/error_reporting/infer/sub_relations.rs +++ /dev/null @@ -1,81 +0,0 @@ -use rustc_data_structures::fx::FxHashMap; -use rustc_data_structures::undo_log::NoUndo; -use rustc_data_structures::unify as ut; -use rustc_middle::ty; - -use crate::infer::InferCtxt; - -#[derive(Debug, Copy, Clone, PartialEq)] -struct SubId(u32); -impl ut::UnifyKey for SubId { - type Value = (); - #[inline] - fn index(&self) -> u32 { - self.0 - } - #[inline] - fn from_index(i: u32) -> SubId { - SubId(i) - } - fn tag() -> &'static str { - "SubId" - } -} - -/// When reporting ambiguity errors, we sometimes want to -/// treat all inference vars which are subtypes of each -/// others as if they are equal. For this case we compute -/// the transitive closure of our subtype obligations here. -/// -/// E.g. when encountering ambiguity errors, we want to suggest -/// specifying some method argument or to add a type annotation -/// to a local variable. Because subtyping cannot change the -/// shape of a type, it's fine if the cause of the ambiguity error -/// is only related to the suggested variable via subtyping. -/// -/// Even for something like `let x = returns_arg(); x.method();` the -/// type of `x` is only a supertype of the argument of `returns_arg`. We -/// still want to suggest specifying the type of the argument. -#[derive(Default)] -pub struct SubRelations { - map: FxHashMap, - table: ut::UnificationTableStorage, -} - -impl SubRelations { - fn get_id<'tcx>(&mut self, infcx: &InferCtxt<'tcx>, vid: ty::TyVid) -> SubId { - let root_vid = infcx.root_var(vid); - *self.map.entry(root_vid).or_insert_with(|| self.table.with_log(&mut NoUndo).new_key(())) - } - - pub fn add_constraints<'tcx>( - &mut self, - infcx: &InferCtxt<'tcx>, - obls: impl IntoIterator>, - ) { - for p in obls { - let (a, b) = match p.kind().skip_binder() { - ty::PredicateKind::Subtype(ty::SubtypePredicate { a_is_expected: _, a, b }) => { - (a, b) - } - ty::PredicateKind::Coerce(ty::CoercePredicate { a, b }) => (a, b), - _ => continue, - }; - - match (a.kind(), b.kind()) { - (&ty::Infer(ty::TyVar(a_vid)), &ty::Infer(ty::TyVar(b_vid))) => { - let a = self.get_id(infcx, a_vid); - let b = self.get_id(infcx, b_vid); - self.table.with_log(&mut NoUndo).unify_var_var(a, b).unwrap(); - } - _ => continue, - } - } - } - - pub fn unified<'tcx>(&mut self, infcx: &InferCtxt<'tcx>, a: ty::TyVid, b: ty::TyVid) -> bool { - let a = self.get_id(infcx, a); - let b = self.get_id(infcx, b); - self.table.with_log(&mut NoUndo).unioned(a, b) - } -} diff --git a/compiler/rustc_trait_selection/src/error_reporting/mod.rs b/compiler/rustc_trait_selection/src/error_reporting/mod.rs index 82695688ae89..cce20b05c79a 100644 --- a/compiler/rustc_trait_selection/src/error_reporting/mod.rs +++ b/compiler/rustc_trait_selection/src/error_reporting/mod.rs @@ -7,8 +7,6 @@ use rustc_macros::extension; use rustc_middle::bug; use rustc_middle::ty::{self, Ty}; -use crate::error_reporting::infer::sub_relations; - pub mod infer; pub mod traits; @@ -21,7 +19,6 @@ pub mod traits; /// methods which should not be used during the happy path. pub struct TypeErrCtxt<'a, 'tcx> { pub infcx: &'a InferCtxt<'tcx>, - pub sub_relations: std::cell::RefCell, pub typeck_results: Option>>, pub fallback_has_occurred: bool, @@ -38,7 +35,6 @@ impl<'tcx> InferCtxt<'tcx> { fn err_ctxt(&self) -> TypeErrCtxt<'_, 'tcx> { TypeErrCtxt { infcx: self, - sub_relations: Default::default(), typeck_results: None, fallback_has_occurred: false, normalize_fn_sig: Box::new(|fn_sig| fn_sig), diff --git a/compiler/rustc_trait_selection/src/error_reporting/traits/mod.rs b/compiler/rustc_trait_selection/src/error_reporting/traits/mod.rs index 78f9287b407b..9d2d64e4398f 100644 --- a/compiler/rustc_trait_selection/src/error_reporting/traits/mod.rs +++ b/compiler/rustc_trait_selection/src/error_reporting/traits/mod.rs @@ -141,10 +141,6 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> { &self, mut errors: Vec>, ) -> ErrorGuaranteed { - self.sub_relations - .borrow_mut() - .add_constraints(self, errors.iter().map(|e| e.obligation.predicate)); - #[derive(Debug)] struct ErrorDescriptor<'tcx> { goal: Goal<'tcx, ty::Predicate<'tcx>>, diff --git a/compiler/rustc_type_ir/src/infer_ctxt.rs b/compiler/rustc_type_ir/src/infer_ctxt.rs index 8fa56c359996..3974d160e26e 100644 --- a/compiler/rustc_type_ir/src/infer_ctxt.rs +++ b/compiler/rustc_type_ir/src/infer_ctxt.rs @@ -186,6 +186,7 @@ pub trait InferCtxtLike: Sized { ) -> U; fn equate_ty_vids_raw(&self, a: ty::TyVid, b: ty::TyVid); + fn sub_ty_vids_raw(&self, a: ty::TyVid, b: ty::TyVid); fn equate_int_vids_raw(&self, a: ty::IntVid, b: ty::IntVid); fn equate_float_vids_raw(&self, a: ty::FloatVid, b: ty::FloatVid); fn equate_const_vids_raw(&self, a: ty::ConstVid, b: ty::ConstVid); From f39776bf53b33350ff38b86321fc4856c8aa699d Mon Sep 17 00:00:00 2001 From: lcnr Date: Mon, 28 Apr 2025 17:21:36 +0000 Subject: [PATCH 2/5] inline `CanonicalTyVarKind` --- .../src/infer/canonical/canonicalizer.rs | 23 ++------ .../rustc_infer/src/infer/canonical/mod.rs | 15 +---- compiler/rustc_middle/src/infer/canonical.rs | 2 +- .../src/canonicalizer.rs | 12 ++-- compiler/rustc_type_ir/src/canonical.rs | 59 ++++++++----------- 5 files changed, 40 insertions(+), 71 deletions(-) diff --git a/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs b/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs index a1a0926cd818..925098ee4072 100644 --- a/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs +++ b/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs @@ -17,8 +17,7 @@ use tracing::debug; use crate::infer::InferCtxt; use crate::infer::canonical::{ - Canonical, CanonicalQueryInput, CanonicalTyVarKind, CanonicalVarInfo, CanonicalVarKind, - OriginalQueryValues, + Canonical, CanonicalQueryInput, CanonicalVarInfo, CanonicalVarKind, OriginalQueryValues, }; impl<'tcx> InferCtxt<'tcx> { @@ -368,9 +367,7 @@ impl<'cx, 'tcx> TypeFolder> for Canonicalizer<'cx, 'tcx> { ui = ty::UniverseIndex::ROOT; } self.canonicalize_ty_var( - CanonicalVarInfo { - kind: CanonicalVarKind::Ty(CanonicalTyVarKind::General(ui)), - }, + CanonicalVarInfo { kind: CanonicalVarKind::Ty(ui) }, t, ) } @@ -382,10 +379,7 @@ impl<'cx, 'tcx> TypeFolder> for Canonicalizer<'cx, 'tcx> { if nt != t { return self.fold_ty(nt); } else { - self.canonicalize_ty_var( - CanonicalVarInfo { kind: CanonicalVarKind::Ty(CanonicalTyVarKind::Int) }, - t, - ) + self.canonicalize_ty_var(CanonicalVarInfo { kind: CanonicalVarKind::Int }, t) } } ty::Infer(ty::FloatVar(vid)) => { @@ -393,10 +387,7 @@ impl<'cx, 'tcx> TypeFolder> for Canonicalizer<'cx, 'tcx> { if nt != t { return self.fold_ty(nt); } else { - self.canonicalize_ty_var( - CanonicalVarInfo { kind: CanonicalVarKind::Ty(CanonicalTyVarKind::Float) }, - t, - ) + self.canonicalize_ty_var(CanonicalVarInfo { kind: CanonicalVarKind::Float }, t) } } @@ -690,12 +681,10 @@ impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> { .iter() .map(|v| CanonicalVarInfo { kind: match v.kind { - CanonicalVarKind::Ty(CanonicalTyVarKind::Int | CanonicalTyVarKind::Float) => { + CanonicalVarKind::Int | CanonicalVarKind::Float => { return *v; } - CanonicalVarKind::Ty(CanonicalTyVarKind::General(u)) => { - CanonicalVarKind::Ty(CanonicalTyVarKind::General(reverse_universe_map[&u])) - } + CanonicalVarKind::Ty(u) => CanonicalVarKind::Ty(reverse_universe_map[&u]), CanonicalVarKind::Region(u) => { CanonicalVarKind::Region(reverse_universe_map[&u]) } diff --git a/compiler/rustc_infer/src/infer/canonical/mod.rs b/compiler/rustc_infer/src/infer/canonical/mod.rs index 3be07dbe208f..08df0f41bd98 100644 --- a/compiler/rustc_infer/src/infer/canonical/mod.rs +++ b/compiler/rustc_infer/src/infer/canonical/mod.rs @@ -108,18 +108,9 @@ impl<'tcx> InferCtxt<'tcx> { universe_map: impl Fn(ty::UniverseIndex) -> ty::UniverseIndex, ) -> GenericArg<'tcx> { match cv_info.kind { - CanonicalVarKind::Ty(ty_kind) => { - let ty = match ty_kind { - CanonicalTyVarKind::General(ui) => { - self.next_ty_var_in_universe(span, universe_map(ui)) - } - - CanonicalTyVarKind::Int => self.next_int_var(), - - CanonicalTyVarKind::Float => self.next_float_var(), - }; - ty.into() - } + CanonicalVarKind::Ty(ui) => self.next_ty_var_in_universe(span, universe_map(ui)).into(), + CanonicalVarKind::Int => self.next_int_var().into(), + CanonicalVarKind::Float => self.next_float_var().into(), CanonicalVarKind::PlaceholderTy(ty::PlaceholderType { universe, bound }) => { let universe_mapped = universe_map(universe); diff --git a/compiler/rustc_middle/src/infer/canonical.rs b/compiler/rustc_middle/src/infer/canonical.rs index 5b8603744961..15cc01ba6a01 100644 --- a/compiler/rustc_middle/src/infer/canonical.rs +++ b/compiler/rustc_middle/src/infer/canonical.rs @@ -27,7 +27,7 @@ use rustc_data_structures::fx::FxHashMap; use rustc_data_structures::sync::Lock; use rustc_macros::{HashStable, TypeFoldable, TypeVisitable}; pub use rustc_type_ir as ir; -pub use rustc_type_ir::{CanonicalTyVarKind, CanonicalVarKind}; +pub use rustc_type_ir::CanonicalVarKind; use smallvec::SmallVec; use crate::mir::ConstraintCategory; diff --git a/compiler/rustc_next_trait_solver/src/canonicalizer.rs b/compiler/rustc_next_trait_solver/src/canonicalizer.rs index bbb4a162027d..bac5767c457f 100644 --- a/compiler/rustc_next_trait_solver/src/canonicalizer.rs +++ b/compiler/rustc_next_trait_solver/src/canonicalizer.rs @@ -4,8 +4,8 @@ use rustc_type_ir::data_structures::{HashMap, ensure_sufficient_stack}; use rustc_type_ir::inherent::*; use rustc_type_ir::solve::{Goal, QueryInput}; use rustc_type_ir::{ - self as ty, Canonical, CanonicalTyVarKind, CanonicalVarInfo, CanonicalVarKind, InferCtxtLike, - Interner, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitableExt, + self as ty, Canonical, CanonicalVarInfo, CanonicalVarKind, InferCtxtLike, Interner, + TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitableExt, }; use crate::delegate::SolverDelegate; @@ -323,11 +323,11 @@ impl<'a, D: SolverDelegate, I: Interner> Canonicalizer<'a, D, I> { "ty vid should have been resolved fully before canonicalization" ); - CanonicalVarKind::Ty(CanonicalTyVarKind::General( + CanonicalVarKind::Ty( self.delegate .universe_of_ty(vid) .unwrap_or_else(|| panic!("ty var should have been resolved: {t:?}")), - )) + ) } ty::IntVar(vid) => { assert_eq!( @@ -335,7 +335,7 @@ impl<'a, D: SolverDelegate, I: Interner> Canonicalizer<'a, D, I> { t, "ty vid should have been resolved fully before canonicalization" ); - CanonicalVarKind::Ty(CanonicalTyVarKind::Int) + CanonicalVarKind::Int } ty::FloatVar(vid) => { assert_eq!( @@ -343,7 +343,7 @@ impl<'a, D: SolverDelegate, I: Interner> Canonicalizer<'a, D, I> { t, "ty vid should have been resolved fully before canonicalization" ); - CanonicalVarKind::Ty(CanonicalTyVarKind::Float) + CanonicalVarKind::Float } ty::FreshTy(_) | ty::FreshIntTy(_) | ty::FreshFloatTy(_) => { panic!("fresh vars not expected in canonicalization") diff --git a/compiler/rustc_type_ir/src/canonical.rs b/compiler/rustc_type_ir/src/canonical.rs index 67b67df4b281..9bf16e1a9414 100644 --- a/compiler/rustc_type_ir/src/canonical.rs +++ b/compiler/rustc_type_ir/src/canonical.rs @@ -110,7 +110,7 @@ impl CanonicalVarInfo { pub fn is_existential(&self) -> bool { match self.kind { - CanonicalVarKind::Ty(_) => true, + CanonicalVarKind::Ty(_) | CanonicalVarKind::Int | CanonicalVarKind::Float => true, CanonicalVarKind::PlaceholderTy(_) => false, CanonicalVarKind::Region(_) => true, CanonicalVarKind::PlaceholderRegion(..) => false, @@ -123,6 +123,8 @@ impl CanonicalVarInfo { match self.kind { CanonicalVarKind::Region(_) | CanonicalVarKind::PlaceholderRegion(_) => true, CanonicalVarKind::Ty(_) + | CanonicalVarKind::Int + | CanonicalVarKind::Float | CanonicalVarKind::PlaceholderTy(_) | CanonicalVarKind::Const(_) | CanonicalVarKind::PlaceholderConst(_) => false, @@ -131,7 +133,11 @@ impl CanonicalVarInfo { pub fn expect_placeholder_index(self) -> usize { match self.kind { - CanonicalVarKind::Ty(_) | CanonicalVarKind::Region(_) | CanonicalVarKind::Const(_) => { + CanonicalVarKind::Ty(_) + | CanonicalVarKind::Int + | CanonicalVarKind::Float + | CanonicalVarKind::Region(_) + | CanonicalVarKind::Const(_) => { panic!("expected placeholder: {self:?}") } @@ -151,8 +157,14 @@ impl CanonicalVarInfo { derive(Decodable_NoContext, Encodable_NoContext, HashStable_NoContext) )] pub enum CanonicalVarKind { - /// Some kind of type inference variable. - Ty(CanonicalTyVarKind), + /// A general type variable `?T` that can be unified with arbitrary types. + Ty(UniverseIndex), + + /// Integral type variable `?I` (that can only be unified with integral types). + Int, + + /// Floating-point type variable `?F` (that can only be unified with float types). + Float, /// A "placeholder" that represents "any type". PlaceholderTy(I::PlaceholderTy), @@ -175,15 +187,13 @@ pub enum CanonicalVarKind { impl CanonicalVarKind { pub fn universe(self) -> UniverseIndex { match self { - CanonicalVarKind::Ty(CanonicalTyVarKind::General(ui)) => ui, + CanonicalVarKind::Ty(ui) => ui, CanonicalVarKind::Region(ui) => ui, CanonicalVarKind::Const(ui) => ui, CanonicalVarKind::PlaceholderTy(placeholder) => placeholder.universe(), CanonicalVarKind::PlaceholderRegion(placeholder) => placeholder.universe(), CanonicalVarKind::PlaceholderConst(placeholder) => placeholder.universe(), - CanonicalVarKind::Ty(CanonicalTyVarKind::Float | CanonicalTyVarKind::Int) => { - UniverseIndex::ROOT - } + CanonicalVarKind::Int | CanonicalVarKind::Float => UniverseIndex::ROOT, } } @@ -193,9 +203,7 @@ impl CanonicalVarKind { /// the updated universe is not the root. pub fn with_updated_universe(self, ui: UniverseIndex) -> CanonicalVarKind { match self { - CanonicalVarKind::Ty(CanonicalTyVarKind::General(_)) => { - CanonicalVarKind::Ty(CanonicalTyVarKind::General(ui)) - } + CanonicalVarKind::Ty(_) => CanonicalVarKind::Ty(ui), CanonicalVarKind::Region(_) => CanonicalVarKind::Region(ui), CanonicalVarKind::Const(_) => CanonicalVarKind::Const(ui), @@ -208,7 +216,7 @@ impl CanonicalVarKind { CanonicalVarKind::PlaceholderConst(placeholder) => { CanonicalVarKind::PlaceholderConst(placeholder.with_updated_universe(ui)) } - CanonicalVarKind::Ty(CanonicalTyVarKind::Int | CanonicalTyVarKind::Float) => { + CanonicalVarKind::Int | CanonicalVarKind::Float => { assert_eq!(ui, UniverseIndex::ROOT); self } @@ -216,28 +224,6 @@ impl CanonicalVarKind { } } -/// Rust actually has more than one category of type variables; -/// notably, the type variables we create for literals (e.g., 22 or -/// 22.) can only be instantiated with integral/float types (e.g., -/// usize or f32). In order to faithfully reproduce a type, we need to -/// know what set of types a given type variable can be unified with. -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] -#[derive(TypeVisitable_Generic, TypeFoldable_Generic)] -#[cfg_attr( - feature = "nightly", - derive(Decodable_NoContext, Encodable_NoContext, HashStable_NoContext) -)] -pub enum CanonicalTyVarKind { - /// General type variable `?T` that can be unified with arbitrary types. - General(UniverseIndex), - - /// Integral type variable `?I` (that can only be unified with integral types). - Int, - - /// Floating-point type variable `?F` (that can only be unified with float types). - Float, -} - /// A set of values corresponding to the canonical variables from some /// `Canonical`. You can give these values to /// `canonical_value.instantiate` to instantiate them into the canonical @@ -311,7 +297,10 @@ impl CanonicalVarValues { var_values: cx.mk_args_from_iter(infos.iter().enumerate().map( |(i, info)| -> I::GenericArg { match info.kind { - CanonicalVarKind::Ty(_) | CanonicalVarKind::PlaceholderTy(_) => { + CanonicalVarKind::Ty(_) + | CanonicalVarKind::Int + | CanonicalVarKind::Float + | CanonicalVarKind::PlaceholderTy(_) => { Ty::new_anon_bound(cx, ty::INNERMOST, ty::BoundVar::from_usize(i)) .into() } From 8eb514adebbc67f581d96d7541a0bf689a62b574 Mon Sep 17 00:00:00 2001 From: lcnr Date: Tue, 6 May 2025 20:53:33 +0000 Subject: [PATCH 3/5] track relevant `sub_relations` in canonical queries This allows canonical queries to also rely on them in the future. It also means it would now be sound to rely on `sub_relations` in the generalizer. --- .../src/infer/canonical/canonicalizer.rs | 19 ++++- .../rustc_infer/src/infer/canonical/mod.rs | 25 ++++-- .../src/infer/canonical/query_response.rs | 74 ++++++++++++------ compiler/rustc_infer/src/infer/context.rs | 3 + .../src/canonicalizer.rs | 22 ++++-- .../rustc_next_trait_solver/src/delegate.rs | 1 + .../src/solve/eval_ctxt/canonical.rs | 77 ++++++++++++------- .../src/solve/delegate.rs | 3 +- compiler/rustc_type_ir/src/canonical.rs | 20 +++-- compiler/rustc_type_ir/src/infer_ctxt.rs | 1 + 10 files changed, 169 insertions(+), 76 deletions(-) diff --git a/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs b/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs index 925098ee4072..a76bd93a12cf 100644 --- a/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs +++ b/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs @@ -6,6 +6,7 @@ //! [c]: https://rust-lang.github.io/chalk/book/canonical_queries/canonicalization.html use rustc_data_structures::fx::FxHashMap; +use rustc_data_structures::sso::SsoHashMap; use rustc_index::Idx; use rustc_middle::bug; use rustc_middle::ty::{ @@ -298,6 +299,7 @@ struct Canonicalizer<'cx, 'tcx> { // Note that indices is only used once `var_values` is big enough to be // heap-allocated. indices: FxHashMap, BoundVar>, + sub_root_lookup_table: SsoHashMap, canonicalize_mode: &'cx dyn CanonicalizeMode, needs_canonical_flags: TypeFlags, @@ -366,8 +368,11 @@ impl<'cx, 'tcx> TypeFolder> for Canonicalizer<'cx, 'tcx> { // FIXME: perf problem described in #55921. ui = ty::UniverseIndex::ROOT; } + let sub_root = self.get_or_insert_sub_root(vid); self.canonicalize_ty_var( - CanonicalVarInfo { kind: CanonicalVarKind::Ty(ui) }, + CanonicalVarInfo { + kind: CanonicalVarKind::Ty { universe: ui, sub_root }, + }, t, ) } @@ -567,6 +572,7 @@ impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> { variables: SmallVec::from_slice(base.variables), query_state, indices: FxHashMap::default(), + sub_root_lookup_table: Default::default(), binder_index: ty::INNERMOST, }; if canonicalizer.query_state.var_values.spilled() { @@ -661,6 +667,13 @@ impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> { } } + fn get_or_insert_sub_root(&mut self, vid: ty::TyVid) -> ty::BoundVar { + let root_vid = self.infcx.unwrap().sub_root_var(vid); + let idx = + *self.sub_root_lookup_table.entry(root_vid).or_insert_with(|| self.variables.len()); + ty::BoundVar::from(idx) + } + /// Replaces the universe indexes used in `var_values` with their index in /// `query_state.universe_map`. This minimizes the maximum universe used in /// the canonicalized value. @@ -684,7 +697,9 @@ impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> { CanonicalVarKind::Int | CanonicalVarKind::Float => { return *v; } - CanonicalVarKind::Ty(u) => CanonicalVarKind::Ty(reverse_universe_map[&u]), + CanonicalVarKind::Ty { universe, sub_root } => { + CanonicalVarKind::Ty { universe: reverse_universe_map[&universe], sub_root } + } CanonicalVarKind::Region(u) => { CanonicalVarKind::Region(reverse_universe_map[&u]) } diff --git a/compiler/rustc_infer/src/infer/canonical/mod.rs b/compiler/rustc_infer/src/infer/canonical/mod.rs index 08df0f41bd98..f4c607fe6a3b 100644 --- a/compiler/rustc_infer/src/infer/canonical/mod.rs +++ b/compiler/rustc_infer/src/infer/canonical/mod.rs @@ -84,13 +84,12 @@ impl<'tcx> InferCtxt<'tcx> { variables: &List>, universe_map: impl Fn(ty::UniverseIndex) -> ty::UniverseIndex, ) -> CanonicalVarValues<'tcx> { - CanonicalVarValues { - var_values: self.tcx.mk_args_from_iter( - variables - .iter() - .map(|info| self.instantiate_canonical_var(span, info, &universe_map)), - ), + let mut var_values = Vec::new(); + for info in variables.iter() { + let value = self.instantiate_canonical_var(span, info, &var_values, &universe_map); + var_values.push(value); } + CanonicalVarValues { var_values: self.tcx.mk_args(&var_values) } } /// Given the "info" about a canonical variable, creates a fresh @@ -105,10 +104,22 @@ impl<'tcx> InferCtxt<'tcx> { &self, span: Span, cv_info: CanonicalVarInfo<'tcx>, + previous_var_values: &[GenericArg<'tcx>], universe_map: impl Fn(ty::UniverseIndex) -> ty::UniverseIndex, ) -> GenericArg<'tcx> { match cv_info.kind { - CanonicalVarKind::Ty(ui) => self.next_ty_var_in_universe(span, universe_map(ui)).into(), + CanonicalVarKind::Ty { universe, sub_root } => { + let vid = self.next_ty_var_id_in_universe(span, universe_map(universe)); + if let Some(prev) = previous_var_values.get(sub_root.as_usize()) { + // We cannot simply assume that previous `var_values` get instantiated + // with inference variables as we may reuse the generic arguments from the + // input which may have gotten constrained after we've canonicalized it. + if let &ty::Infer(ty::TyVar(sub_root)) = prev.expect_ty().kind() { + self.inner.borrow_mut().type_variables().sub(vid, sub_root); + } + } + Ty::new_var(self.tcx, vid).into() + } CanonicalVarKind::Int => self.next_int_var().into(), CanonicalVarKind::Float => self.next_float_var().into(), diff --git a/compiler/rustc_infer/src/infer/canonical/query_response.rs b/compiler/rustc_infer/src/infer/canonical/query_response.rs index 5220071c5005..6ce4e601aa82 100644 --- a/compiler/rustc_infer/src/infer/canonical/query_response.rs +++ b/compiler/rustc_infer/src/infer/canonical/query_response.rs @@ -13,7 +13,9 @@ use std::iter; use rustc_index::{Idx, IndexVec}; use rustc_middle::arena::ArenaAllocatable; use rustc_middle::mir::ConstraintCategory; -use rustc_middle::ty::{self, BoundVar, GenericArg, GenericArgKind, Ty, TyCtxt, TypeFoldable}; +use rustc_middle::ty::{ + self, BoundVar, CanonicalVarKind, GenericArg, GenericArgKind, Ty, TyCtxt, TypeFoldable, +}; use rustc_middle::{bug, span_bug}; use tracing::{debug, instrument}; @@ -455,32 +457,54 @@ impl<'tcx> InferCtxt<'tcx> { // Create result arguments: if we found a value for a // given variable in the loop above, use that. Otherwise, use // a fresh inference variable. - let result_args = CanonicalVarValues { - var_values: self.tcx.mk_args_from_iter( - query_response.variables.iter().enumerate().map(|(index, info)| { - if info.universe() != ty::UniverseIndex::ROOT { - // A variable from inside a binder of the query. While ideally these shouldn't - // exist at all, we have to deal with them for now. - self.instantiate_canonical_var(cause.span, info, |u| { - universe_map[u.as_usize()] - }) - } else if info.is_existential() { - match opt_values[BoundVar::new(index)] { - Some(k) => k, - None => self.instantiate_canonical_var(cause.span, info, |u| { - universe_map[u.as_usize()] - }), + let mut var_values = Vec::new(); + for (index, info) in query_response.variables.iter().enumerate() { + let value = if info.universe() != ty::UniverseIndex::ROOT { + // A variable from inside a binder of the query. While ideally these shouldn't + // exist at all, we have to deal with them for now. + self.instantiate_canonical_var(cause.span, info, &var_values, |u| { + universe_map[u.as_usize()] + }) + } else if info.is_existential() { + // As an optimization we sometimes avoid creating a new inference variable here. + // We need to still make sure to register any subtype relations returned by the + // query. + match opt_values[BoundVar::new(index)] { + Some(v) => { + if let CanonicalVarKind::Ty { universe: _, sub_root } = info.kind { + if let Some(prev) = var_values.get(sub_root.as_usize()) { + // We cannot simply assume that previous `var_values` + // are inference variables, see the comment in + // `instantiate_canonical_var`. + let v = self.shallow_resolve(v.expect_ty()); + let prev = self.shallow_resolve(prev.expect_ty()); + match (v.kind(), prev.kind()) { + ( + &ty::Infer(ty::TyVar(vid)), + &ty::Infer(ty::TyVar(sub_root)), + ) => { + self.inner.borrow_mut().type_variables().sub(vid, sub_root) + } + _ => {} + } + } } - } else { - // For placeholders which were already part of the input, we simply map this - // universal bound variable back the placeholder of the input. - opt_values[BoundVar::new(index)].expect( - "expected placeholder to be unified with itself during response", - ) + v } - }), - ), - }; + None => self.instantiate_canonical_var(cause.span, info, &var_values, |u| { + universe_map[u.as_usize()] + }), + } + } else { + // For placeholders which were already part of the input, we simply map this + // universal bound variable back the placeholder of the input. + opt_values[BoundVar::new(index)] + .expect("expected placeholder to be unified with itself during response") + }; + var_values.push(value) + } + + let result_args = CanonicalVarValues { var_values: self.tcx.mk_args(&var_values) }; let mut obligations = PredicateObligations::new(); diff --git a/compiler/rustc_infer/src/infer/context.rs b/compiler/rustc_infer/src/infer/context.rs index 5b70eaf9cca8..ae773bab4680 100644 --- a/compiler/rustc_infer/src/infer/context.rs +++ b/compiler/rustc_infer/src/infer/context.rs @@ -55,6 +55,9 @@ impl<'tcx> rustc_type_ir::InferCtxtLike for InferCtxt<'tcx> { fn root_ty_var(&self, var: ty::TyVid) -> ty::TyVid { self.root_var(var) } + fn sub_root_ty_var(&self, var: ty::TyVid) -> ty::TyVid { + self.sub_root_var(var) + } fn root_const_var(&self, var: ty::ConstVid) -> ty::ConstVid { self.root_const_var(var) diff --git a/compiler/rustc_next_trait_solver/src/canonicalizer.rs b/compiler/rustc_next_trait_solver/src/canonicalizer.rs index bac5767c457f..5097f9ec545a 100644 --- a/compiler/rustc_next_trait_solver/src/canonicalizer.rs +++ b/compiler/rustc_next_trait_solver/src/canonicalizer.rs @@ -52,6 +52,7 @@ pub struct Canonicalizer<'a, D: SolverDelegate, I: Interner> { variables: &'a mut Vec, primitive_var_infos: Vec>, variable_lookup_table: HashMap, + sub_root_lookup_table: HashMap, binder_index: ty::DebruijnIndex, /// We only use the debruijn index during lookup. We don't need to @@ -73,6 +74,7 @@ impl<'a, D: SolverDelegate, I: Interner> Canonicalizer<'a, D, I> { variables, variable_lookup_table: Default::default(), + sub_root_lookup_table: Default::default(), primitive_var_infos: Vec::new(), binder_index: ty::INNERMOST, @@ -106,6 +108,7 @@ impl<'a, D: SolverDelegate, I: Interner> Canonicalizer<'a, D, I> { variables, variable_lookup_table: Default::default(), + sub_root_lookup_table: Default::default(), primitive_var_infos: Vec::new(), binder_index: ty::INNERMOST, @@ -123,6 +126,7 @@ impl<'a, D: SolverDelegate, I: Interner> Canonicalizer<'a, D, I> { // We're able to reuse the `variable_lookup_table` as whether or not // it already contains an entry for `'static` does not matter. variable_lookup_table: env_canonicalizer.variable_lookup_table, + sub_root_lookup_table: Default::default(), primitive_var_infos: env_canonicalizer.primitive_var_infos, binder_index: ty::INNERMOST, @@ -177,6 +181,13 @@ impl<'a, D: SolverDelegate, I: Interner> Canonicalizer<'a, D, I> { ty::BoundVar::from(idx) } + fn get_or_insert_sub_root(&mut self, vid: ty::TyVid) -> ty::BoundVar { + let root_vid = self.delegate.sub_root_ty_var(vid); + let idx = + *self.sub_root_lookup_table.entry(root_vid).or_insert_with(|| self.variables.len()); + ty::BoundVar::from(idx) + } + fn finalize(self) -> (ty::UniverseIndex, I::CanonicalVars) { let mut var_infos = self.primitive_var_infos; // See the rustc-dev-guide section about how we deal with universes @@ -323,11 +334,12 @@ impl<'a, D: SolverDelegate, I: Interner> Canonicalizer<'a, D, I> { "ty vid should have been resolved fully before canonicalization" ); - CanonicalVarKind::Ty( - self.delegate - .universe_of_ty(vid) - .unwrap_or_else(|| panic!("ty var should have been resolved: {t:?}")), - ) + let universe = self + .delegate + .universe_of_ty(vid) + .unwrap_or_else(|| panic!("ty var should have been resolved: {t:?}")); + let sub_root = self.get_or_insert_sub_root(vid); + CanonicalVarKind::Ty { universe, sub_root } } ty::IntVar(vid) => { assert_eq!( diff --git a/compiler/rustc_next_trait_solver/src/delegate.rs b/compiler/rustc_next_trait_solver/src/delegate.rs index 25493970a0ce..fe229056a93c 100644 --- a/compiler/rustc_next_trait_solver/src/delegate.rs +++ b/compiler/rustc_next_trait_solver/src/delegate.rs @@ -59,6 +59,7 @@ pub trait SolverDelegate: Deref + Sized { &self, cv_info: ty::CanonicalVarInfo, span: ::Span, + var_values: &[::GenericArg], universe_map: impl Fn(ty::UniverseIndex) -> ty::UniverseIndex, ) -> ::GenericArg; diff --git a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/canonical.rs b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/canonical.rs index dded84f67686..30780167911c 100644 --- a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/canonical.rs +++ b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/canonical.rs @@ -15,7 +15,8 @@ use rustc_index::IndexVec; use rustc_type_ir::inherent::*; use rustc_type_ir::relate::solver_relating::RelateExt; use rustc_type_ir::{ - self as ty, Canonical, CanonicalVarValues, InferCtxtLike, Interner, TypeFoldable, + self as ty, Canonical, CanonicalVarKind, CanonicalVarValues, InferCtxtLike, Interner, + TypeFoldable, }; use tracing::{debug, instrument, trace}; @@ -354,37 +355,55 @@ where } } - let var_values = delegate.cx().mk_args_from_iter( - response.variables.iter().enumerate().map(|(index, info)| { - if info.universe() != ty::UniverseIndex::ROOT { - // A variable from inside a binder of the query. While ideally these shouldn't - // exist at all (see the FIXME at the start of this method), we have to deal with - // them for now. - delegate.instantiate_canonical_var_with_infer(info, span, |idx| { - prev_universe + idx.index() - }) - } else if info.is_existential() { - // As an optimization we sometimes avoid creating a new inference variable here. - // - // All new inference variables we create start out in the current universe of the caller. - // This is conceptually wrong as these inference variables would be able to name - // more placeholders then they should be able to. However the inference variables have - // to "come from somewhere", so by equating them with the original values of the caller - // later on, we pull them down into their correct universe again. - if let Some(v) = opt_values[ty::BoundVar::from_usize(index)] { - v - } else { - delegate.instantiate_canonical_var_with_infer(info, span, |_| prev_universe) + let mut var_values = Vec::new(); + for (index, info) in response.variables.iter().enumerate() { + let value = if info.universe() != ty::UniverseIndex::ROOT { + // A variable from inside a binder of the query. While ideally these shouldn't + // exist at all (see the FIXME at the start of this method), we have to deal with + // them for now. + delegate.instantiate_canonical_var_with_infer(info, span, &var_values, |idx| { + prev_universe + idx.index() + }) + } else if info.is_existential() { + // As an optimization we sometimes avoid creating a new inference variable here. + // We need to still make sure to register any subtype relations returned by the + // query. + if let Some(v) = opt_values[ty::BoundVar::from_usize(index)] { + if let CanonicalVarKind::Ty { universe: _, sub_root } = info.kind { + if let Some(prev) = var_values.get(sub_root.as_usize()) { + // We cannot simply assume that previous `var_values` + // are inference variables, see the comment in + // `instantiate_canonical_var`. + let v = delegate.shallow_resolve(v.expect_ty()); + let prev = delegate.shallow_resolve(prev.expect_ty()); + match (v.kind(), prev.kind()) { + (ty::Infer(ty::TyVar(vid)), ty::Infer(ty::TyVar(sub_root))) => { + delegate.sub_ty_vids_raw(vid, sub_root) + } + _ => {} + } + } } + v } else { - // For placeholders which were already part of the input, we simply map this - // universal bound variable back the placeholder of the input. - original_values[info.expect_placeholder_index()] + // All new inference variables we create start out in the current universe + // of the caller. This is conceptually wrong as these inference variables + // would be able to name more placeholders then they should be able to. + // However the inference variables have to "come from somewhere", so by + // equating them with the original values of the caller later on, we pull + // them down into their correct universe again. + delegate.instantiate_canonical_var_with_infer(info, span, &var_values, |_| { + prev_universe + }) } - }), - ); - - CanonicalVarValues { var_values } + } else { + // For placeholders which were already part of the input, we simply map this + // universal bound variable back the placeholder of the input. + original_values[info.expect_placeholder_index()] + }; + var_values.push(value) + } + CanonicalVarValues { var_values: delegate.cx().mk_args(&var_values) } } /// Unify the `original_values` with the `var_values` returned by the canonical query.. diff --git a/compiler/rustc_trait_selection/src/solve/delegate.rs b/compiler/rustc_trait_selection/src/solve/delegate.rs index 908c058aabec..535bd1e3f9fd 100644 --- a/compiler/rustc_trait_selection/src/solve/delegate.rs +++ b/compiler/rustc_trait_selection/src/solve/delegate.rs @@ -148,9 +148,10 @@ impl<'tcx> rustc_next_trait_solver::delegate::SolverDelegate for SolverDelegate< &self, cv_info: CanonicalVarInfo<'tcx>, span: Span, + var_values: &[ty::GenericArg<'tcx>], universe_map: impl Fn(ty::UniverseIndex) -> ty::UniverseIndex, ) -> ty::GenericArg<'tcx> { - self.0.instantiate_canonical_var(span, cv_info, universe_map) + self.0.instantiate_canonical_var(span, cv_info, var_values, universe_map) } fn register_hidden_type_in_storage( diff --git a/compiler/rustc_type_ir/src/canonical.rs b/compiler/rustc_type_ir/src/canonical.rs index 9bf16e1a9414..b65547ba04b8 100644 --- a/compiler/rustc_type_ir/src/canonical.rs +++ b/compiler/rustc_type_ir/src/canonical.rs @@ -110,7 +110,7 @@ impl CanonicalVarInfo { pub fn is_existential(&self) -> bool { match self.kind { - CanonicalVarKind::Ty(_) | CanonicalVarKind::Int | CanonicalVarKind::Float => true, + CanonicalVarKind::Ty { .. } | CanonicalVarKind::Int | CanonicalVarKind::Float => true, CanonicalVarKind::PlaceholderTy(_) => false, CanonicalVarKind::Region(_) => true, CanonicalVarKind::PlaceholderRegion(..) => false, @@ -122,7 +122,7 @@ impl CanonicalVarInfo { pub fn is_region(&self) -> bool { match self.kind { CanonicalVarKind::Region(_) | CanonicalVarKind::PlaceholderRegion(_) => true, - CanonicalVarKind::Ty(_) + CanonicalVarKind::Ty { .. } | CanonicalVarKind::Int | CanonicalVarKind::Float | CanonicalVarKind::PlaceholderTy(_) @@ -133,7 +133,7 @@ impl CanonicalVarInfo { pub fn expect_placeholder_index(self) -> usize { match self.kind { - CanonicalVarKind::Ty(_) + CanonicalVarKind::Ty { .. } | CanonicalVarKind::Int | CanonicalVarKind::Float | CanonicalVarKind::Region(_) @@ -158,7 +158,11 @@ impl CanonicalVarInfo { )] pub enum CanonicalVarKind { /// A general type variable `?T` that can be unified with arbitrary types. - Ty(UniverseIndex), + /// + /// We also store the index of the first type variable which is sub-unified + /// with this one. If there is no inference variable related to this one, + /// its `sub_root` just points to itself. + Ty { universe: UniverseIndex, sub_root: ty::BoundVar }, /// Integral type variable `?I` (that can only be unified with integral types). Int, @@ -187,7 +191,7 @@ pub enum CanonicalVarKind { impl CanonicalVarKind { pub fn universe(self) -> UniverseIndex { match self { - CanonicalVarKind::Ty(ui) => ui, + CanonicalVarKind::Ty { universe, sub_root: _ } => universe, CanonicalVarKind::Region(ui) => ui, CanonicalVarKind::Const(ui) => ui, CanonicalVarKind::PlaceholderTy(placeholder) => placeholder.universe(), @@ -203,7 +207,9 @@ impl CanonicalVarKind { /// the updated universe is not the root. pub fn with_updated_universe(self, ui: UniverseIndex) -> CanonicalVarKind { match self { - CanonicalVarKind::Ty(_) => CanonicalVarKind::Ty(ui), + CanonicalVarKind::Ty { universe: _, sub_root } => { + CanonicalVarKind::Ty { universe: ui, sub_root } + } CanonicalVarKind::Region(_) => CanonicalVarKind::Region(ui), CanonicalVarKind::Const(_) => CanonicalVarKind::Const(ui), @@ -297,7 +303,7 @@ impl CanonicalVarValues { var_values: cx.mk_args_from_iter(infos.iter().enumerate().map( |(i, info)| -> I::GenericArg { match info.kind { - CanonicalVarKind::Ty(_) + CanonicalVarKind::Ty { .. } | CanonicalVarKind::Int | CanonicalVarKind::Float | CanonicalVarKind::PlaceholderTy(_) => { diff --git a/compiler/rustc_type_ir/src/infer_ctxt.rs b/compiler/rustc_type_ir/src/infer_ctxt.rs index 3974d160e26e..70467622b8eb 100644 --- a/compiler/rustc_type_ir/src/infer_ctxt.rs +++ b/compiler/rustc_type_ir/src/infer_ctxt.rs @@ -149,6 +149,7 @@ pub trait InferCtxtLike: Sized { fn universe_of_ct(&self, ct: ty::ConstVid) -> Option; fn root_ty_var(&self, var: ty::TyVid) -> ty::TyVid; + fn sub_root_ty_var(&self, var: ty::TyVid) -> ty::TyVid; fn root_const_var(&self, var: ty::ConstVid) -> ty::ConstVid; fn opportunistic_resolve_ty_var(&self, vid: ty::TyVid) -> ::Ty; From 88dd764b74217fdbd8d7c7d58273ec8dcd8a1ccc Mon Sep 17 00:00:00 2001 From: lcnr Date: Wed, 7 May 2025 02:18:14 +0000 Subject: [PATCH 4/5] bless mir-opt --- ...s_of_reborrow.SimplifyCfg-initial.after.mir | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/mir-opt/address_of.address_of_reborrow.SimplifyCfg-initial.after.mir b/tests/mir-opt/address_of.address_of_reborrow.SimplifyCfg-initial.after.mir index 5fc77f95eaf7..397bc48170c3 100644 --- a/tests/mir-opt/address_of.address_of_reborrow.SimplifyCfg-initial.after.mir +++ b/tests/mir-opt/address_of.address_of_reborrow.SimplifyCfg-initial.after.mir @@ -1,30 +1,30 @@ // MIR for `address_of_reborrow` after SimplifyCfg-initial | User Type Annotations -| 0: user_ty: Canonical { value: Ty(*const ^0), max_universe: U0, variables: [CanonicalVarInfo { kind: Ty(General(U0)) }] }, span: $DIR/address_of.rs:8:10: 8:18, inferred_ty: *const [i32; 10] +| 0: user_ty: Canonical { value: Ty(*const ^0), max_universe: U0, variables: [CanonicalVarInfo { kind: Ty { universe: U0, sub_root: 0 } }] }, span: $DIR/address_of.rs:8:10: 8:18, inferred_ty: *const [i32; 10] | 1: user_ty: Canonical { value: Ty(*const dyn std::marker::Send), max_universe: U0, variables: [CanonicalVarInfo { kind: Region(U0) }] }, span: $DIR/address_of.rs:10:10: 10:25, inferred_ty: *const dyn std::marker::Send -| 2: user_ty: Canonical { value: Ty(*const ^0), max_universe: U0, variables: [CanonicalVarInfo { kind: Ty(General(U0)) }] }, span: $DIR/address_of.rs:14:12: 14:20, inferred_ty: *const [i32; 10] -| 3: user_ty: Canonical { value: Ty(*const ^0), max_universe: U0, variables: [CanonicalVarInfo { kind: Ty(General(U0)) }] }, span: $DIR/address_of.rs:14:12: 14:20, inferred_ty: *const [i32; 10] +| 2: user_ty: Canonical { value: Ty(*const ^0), max_universe: U0, variables: [CanonicalVarInfo { kind: Ty { universe: U0, sub_root: 0 } }] }, span: $DIR/address_of.rs:14:12: 14:20, inferred_ty: *const [i32; 10] +| 3: user_ty: Canonical { value: Ty(*const ^0), max_universe: U0, variables: [CanonicalVarInfo { kind: Ty { universe: U0, sub_root: 0 } }] }, span: $DIR/address_of.rs:14:12: 14:20, inferred_ty: *const [i32; 10] | 4: user_ty: Canonical { value: Ty(*const [i32; 10]), max_universe: U0, variables: [] }, span: $DIR/address_of.rs:15:12: 15:28, inferred_ty: *const [i32; 10] | 5: user_ty: Canonical { value: Ty(*const [i32; 10]), max_universe: U0, variables: [] }, span: $DIR/address_of.rs:15:12: 15:28, inferred_ty: *const [i32; 10] | 6: user_ty: Canonical { value: Ty(*const dyn std::marker::Send), max_universe: U0, variables: [CanonicalVarInfo { kind: Region(U0) }] }, span: $DIR/address_of.rs:16:12: 16:27, inferred_ty: *const dyn std::marker::Send | 7: user_ty: Canonical { value: Ty(*const dyn std::marker::Send), max_universe: U0, variables: [CanonicalVarInfo { kind: Region(U0) }] }, span: $DIR/address_of.rs:16:12: 16:27, inferred_ty: *const dyn std::marker::Send | 8: user_ty: Canonical { value: Ty(*const [i32]), max_universe: U0, variables: [] }, span: $DIR/address_of.rs:17:12: 17:24, inferred_ty: *const [i32] | 9: user_ty: Canonical { value: Ty(*const [i32]), max_universe: U0, variables: [] }, span: $DIR/address_of.rs:17:12: 17:24, inferred_ty: *const [i32] -| 10: user_ty: Canonical { value: Ty(*const ^0), max_universe: U0, variables: [CanonicalVarInfo { kind: Ty(General(U0)) }] }, span: $DIR/address_of.rs:19:10: 19:18, inferred_ty: *const [i32; 10] +| 10: user_ty: Canonical { value: Ty(*const ^0), max_universe: U0, variables: [CanonicalVarInfo { kind: Ty { universe: U0, sub_root: 0 } }] }, span: $DIR/address_of.rs:19:10: 19:18, inferred_ty: *const [i32; 10] | 11: user_ty: Canonical { value: Ty(*const dyn std::marker::Send), max_universe: U0, variables: [CanonicalVarInfo { kind: Region(U0) }] }, span: $DIR/address_of.rs:21:10: 21:25, inferred_ty: *const dyn std::marker::Send -| 12: user_ty: Canonical { value: Ty(*const ^0), max_universe: U0, variables: [CanonicalVarInfo { kind: Ty(General(U0)) }] }, span: $DIR/address_of.rs:24:12: 24:20, inferred_ty: *const [i32; 10] -| 13: user_ty: Canonical { value: Ty(*const ^0), max_universe: U0, variables: [CanonicalVarInfo { kind: Ty(General(U0)) }] }, span: $DIR/address_of.rs:24:12: 24:20, inferred_ty: *const [i32; 10] +| 12: user_ty: Canonical { value: Ty(*const ^0), max_universe: U0, variables: [CanonicalVarInfo { kind: Ty { universe: U0, sub_root: 0 } }] }, span: $DIR/address_of.rs:24:12: 24:20, inferred_ty: *const [i32; 10] +| 13: user_ty: Canonical { value: Ty(*const ^0), max_universe: U0, variables: [CanonicalVarInfo { kind: Ty { universe: U0, sub_root: 0 } }] }, span: $DIR/address_of.rs:24:12: 24:20, inferred_ty: *const [i32; 10] | 14: user_ty: Canonical { value: Ty(*const [i32; 10]), max_universe: U0, variables: [] }, span: $DIR/address_of.rs:25:12: 25:28, inferred_ty: *const [i32; 10] | 15: user_ty: Canonical { value: Ty(*const [i32; 10]), max_universe: U0, variables: [] }, span: $DIR/address_of.rs:25:12: 25:28, inferred_ty: *const [i32; 10] | 16: user_ty: Canonical { value: Ty(*const dyn std::marker::Send), max_universe: U0, variables: [CanonicalVarInfo { kind: Region(U0) }] }, span: $DIR/address_of.rs:26:12: 26:27, inferred_ty: *const dyn std::marker::Send | 17: user_ty: Canonical { value: Ty(*const dyn std::marker::Send), max_universe: U0, variables: [CanonicalVarInfo { kind: Region(U0) }] }, span: $DIR/address_of.rs:26:12: 26:27, inferred_ty: *const dyn std::marker::Send | 18: user_ty: Canonical { value: Ty(*const [i32]), max_universe: U0, variables: [] }, span: $DIR/address_of.rs:27:12: 27:24, inferred_ty: *const [i32] | 19: user_ty: Canonical { value: Ty(*const [i32]), max_universe: U0, variables: [] }, span: $DIR/address_of.rs:27:12: 27:24, inferred_ty: *const [i32] -| 20: user_ty: Canonical { value: Ty(*mut ^0), max_universe: U0, variables: [CanonicalVarInfo { kind: Ty(General(U0)) }] }, span: $DIR/address_of.rs:29:10: 29:16, inferred_ty: *mut [i32; 10] +| 20: user_ty: Canonical { value: Ty(*mut ^0), max_universe: U0, variables: [CanonicalVarInfo { kind: Ty { universe: U0, sub_root: 0 } }] }, span: $DIR/address_of.rs:29:10: 29:16, inferred_ty: *mut [i32; 10] | 21: user_ty: Canonical { value: Ty(*mut dyn std::marker::Send), max_universe: U0, variables: [CanonicalVarInfo { kind: Region(U0) }] }, span: $DIR/address_of.rs:31:10: 31:23, inferred_ty: *mut dyn std::marker::Send -| 22: user_ty: Canonical { value: Ty(*mut ^0), max_universe: U0, variables: [CanonicalVarInfo { kind: Ty(General(U0)) }] }, span: $DIR/address_of.rs:34:12: 34:18, inferred_ty: *mut [i32; 10] -| 23: user_ty: Canonical { value: Ty(*mut ^0), max_universe: U0, variables: [CanonicalVarInfo { kind: Ty(General(U0)) }] }, span: $DIR/address_of.rs:34:12: 34:18, inferred_ty: *mut [i32; 10] +| 22: user_ty: Canonical { value: Ty(*mut ^0), max_universe: U0, variables: [CanonicalVarInfo { kind: Ty { universe: U0, sub_root: 0 } }] }, span: $DIR/address_of.rs:34:12: 34:18, inferred_ty: *mut [i32; 10] +| 23: user_ty: Canonical { value: Ty(*mut ^0), max_universe: U0, variables: [CanonicalVarInfo { kind: Ty { universe: U0, sub_root: 0 } }] }, span: $DIR/address_of.rs:34:12: 34:18, inferred_ty: *mut [i32; 10] | 24: user_ty: Canonical { value: Ty(*mut [i32; 10]), max_universe: U0, variables: [] }, span: $DIR/address_of.rs:35:12: 35:26, inferred_ty: *mut [i32; 10] | 25: user_ty: Canonical { value: Ty(*mut [i32; 10]), max_universe: U0, variables: [] }, span: $DIR/address_of.rs:35:12: 35:26, inferred_ty: *mut [i32; 10] | 26: user_ty: Canonical { value: Ty(*mut dyn std::marker::Send), max_universe: U0, variables: [CanonicalVarInfo { kind: Region(U0) }] }, span: $DIR/address_of.rs:36:12: 36:25, inferred_ty: *mut dyn std::marker::Send From 53c24cb8daaa7ad000457681335208bb47375d7c Mon Sep 17 00:00:00 2001 From: lcnr Date: Wed, 7 May 2025 17:09:39 +0000 Subject: [PATCH 5/5] rarw --- .../src/infer/canonical/query_response.rs | 39 ++----------------- 1 file changed, 4 insertions(+), 35 deletions(-) diff --git a/compiler/rustc_infer/src/infer/canonical/query_response.rs b/compiler/rustc_infer/src/infer/canonical/query_response.rs index 6ce4e601aa82..2e0e12e0e68f 100644 --- a/compiler/rustc_infer/src/infer/canonical/query_response.rs +++ b/compiler/rustc_infer/src/infer/canonical/query_response.rs @@ -13,9 +13,7 @@ use std::iter; use rustc_index::{Idx, IndexVec}; use rustc_middle::arena::ArenaAllocatable; use rustc_middle::mir::ConstraintCategory; -use rustc_middle::ty::{ - self, BoundVar, CanonicalVarKind, GenericArg, GenericArgKind, Ty, TyCtxt, TypeFoldable, -}; +use rustc_middle::ty::{self, BoundVar, GenericArg, GenericArgKind, Ty, TyCtxt, TypeFoldable}; use rustc_middle::{bug, span_bug}; use tracing::{debug, instrument}; @@ -422,16 +420,7 @@ impl<'tcx> InferCtxt<'tcx> { for (original_value, result_value) in iter::zip(&original_values.var_values, result_values) { match result_value.unpack() { - GenericArgKind::Type(result_value) => { - // e.g., here `result_value` might be `?0` in the example above... - if let ty::Bound(debruijn, b) = *result_value.kind() { - // ...in which case we would set `canonical_vars[0]` to `Some(?U)`. - - // We only allow a `ty::INNERMOST` index in generic parameters. - assert_eq!(debruijn, ty::INNERMOST); - opt_values[b.var] = Some(*original_value); - } - } + GenericArgKind::Type(_) => {} GenericArgKind::Lifetime(result_value) => { // e.g., here `result_value` might be `'?1` in the example above... if let ty::ReBound(debruijn, br) = result_value.kind() { @@ -457,7 +446,7 @@ impl<'tcx> InferCtxt<'tcx> { // Create result arguments: if we found a value for a // given variable in the loop above, use that. Otherwise, use // a fresh inference variable. - let mut var_values = Vec::new(); + let mut var_values = Vec::with_capacity(query_response.variables.len()); for (index, info) in query_response.variables.iter().enumerate() { let value = if info.universe() != ty::UniverseIndex::ROOT { // A variable from inside a binder of the query. While ideally these shouldn't @@ -470,27 +459,7 @@ impl<'tcx> InferCtxt<'tcx> { // We need to still make sure to register any subtype relations returned by the // query. match opt_values[BoundVar::new(index)] { - Some(v) => { - if let CanonicalVarKind::Ty { universe: _, sub_root } = info.kind { - if let Some(prev) = var_values.get(sub_root.as_usize()) { - // We cannot simply assume that previous `var_values` - // are inference variables, see the comment in - // `instantiate_canonical_var`. - let v = self.shallow_resolve(v.expect_ty()); - let prev = self.shallow_resolve(prev.expect_ty()); - match (v.kind(), prev.kind()) { - ( - &ty::Infer(ty::TyVar(vid)), - &ty::Infer(ty::TyVar(sub_root)), - ) => { - self.inner.borrow_mut().type_variables().sub(vid, sub_root) - } - _ => {} - } - } - } - v - } + Some(v) => v, None => self.instantiate_canonical_var(cause.span, info, &var_values, |u| { universe_map[u.as_usize()] }),