Skip to content

Eliminate ObligationCauseData. #73983

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/librustc_infer/infer/error_reporting/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
cause: &ObligationCause<'tcx>,
exp_found: Option<ty::error::ExpectedFound<Ty<'tcx>>>,
) {
match cause.code {
match *cause.code() {
ObligationCauseCode::Pattern { origin_expr: true, span: Some(span), root_ty } => {
let ty = self.resolve_vars_if_possible(&root_ty);
if ty.is_suggestable() {
Expand Down Expand Up @@ -2058,7 +2058,7 @@ impl<'tcx> ObligationCauseExt<'tcx> for ObligationCause<'tcx> {
fn as_failure_code(&self, terr: &TypeError<'tcx>) -> FailureCode {
use self::FailureCode::*;
use crate::traits::ObligationCauseCode::*;
match self.code {
match self.code() {
CompareImplMethodObligation { .. } => Error0308("method not compatible with trait"),
CompareImplTypeObligation { .. } => Error0308("type not compatible with trait"),
MatchExpressionArm(box MatchExpressionArmCause { source, .. }) => {
Expand Down Expand Up @@ -2097,7 +2097,7 @@ impl<'tcx> ObligationCauseExt<'tcx> for ObligationCause<'tcx> {

fn as_requirement_str(&self) -> &'static str {
use crate::traits::ObligationCauseCode::*;
match self.code {
match self.code() {
CompareImplMethodObligation { .. } => "method type is compatible with trait",
CompareImplTypeObligation { .. } => "associated type is compatible with trait",
ExprAssignable => "expression is assignable",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ impl NiceRegionError<'me, 'tcx> {
format!("trait `{}` defined here", self.tcx().def_path_str(trait_def_id)),
);

let leading_ellipsis = if let ObligationCauseCode::ItemObligation(def_id) = cause.code {
let leading_ellipsis = if let ObligationCauseCode::ItemObligation(def_id) = *cause.code() {
err.span_label(span, "doesn't satisfy where-clause");
err.span_label(
self.tcx().def_span(def_id),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ impl<'a, 'tcx> NiceRegionError<'a, 'tcx> {
ValuePairs::Types(sub_expected_found),
ValuePairs::Types(sup_expected_found),
CompareImplMethodObligation { trait_item_def_id, .. },
) = (&sub_trace.values, &sup_trace.values, &sub_trace.cause.code)
) = (&sub_trace.values, &sup_trace.values, sub_trace.cause.code())
{
if sup_expected_found == sub_expected_found {
self.emit_err(
Expand Down
2 changes: 1 addition & 1 deletion src/librustc_infer/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1774,7 +1774,7 @@ impl<'tcx> SubregionOrigin<'tcx> {
where
F: FnOnce() -> Self,
{
match cause.code {
match *cause.code() {
traits::ObligationCauseCode::ReferenceOutlivesReferent(ref_type) => {
SubregionOrigin::ReferenceOutlivesReferent(ref_type, cause.span)
}
Expand Down
2 changes: 1 addition & 1 deletion src/librustc_infer/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ pub type TraitObligation<'tcx> = Obligation<'tcx, ty::PolyTraitPredicate<'tcx>>;

// `PredicateObligation` is used a lot. Make sure it doesn't unintentionally get bigger.
#[cfg(target_arch = "x86_64")]
static_assert_size!(PredicateObligation<'_>, 48);
static_assert_size!(PredicateObligation<'_>, 64);

pub type Obligations<'tcx, O> = Vec<Obligation<'tcx, O>>;
pub type PredicateObligations<'tcx> = Vec<PredicateObligation<'tcx>>;
Expand Down
68 changes: 32 additions & 36 deletions src/librustc_middle/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ use rustc_span::{Span, DUMMY_SP};
use smallvec::SmallVec;

use std::borrow::Cow;
use std::fmt;
use std::ops::Deref;
//use std::fmt; // njn: temp
use std::rc::Rc;

pub use self::select::{EvaluationCache, EvaluationResult, OverflowError, SelectionCache};
Expand Down Expand Up @@ -81,38 +80,14 @@ pub enum Reveal {

/// The reason why we incurred this obligation; used for error reporting.
///
/// As the happy path does not care about this struct, storing this on the heap
/// ends up increasing performance.
/// Non-dummy `ObligationCauseCode`s are stored on the heap. This gives best
/// trade-off between keeping the type small (which makes copies cheaper) while
/// not doing too many heap allocations.
///
/// We do not want to intern this as there are a lot of obligation causes which
/// only live for a short period of time.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct ObligationCause<'tcx> {
/// `None` for `ObligationCause::dummy`, `Some` otherwise.
data: Option<Rc<ObligationCauseData<'tcx>>>,
}

const DUMMY_OBLIGATION_CAUSE_DATA: ObligationCauseData<'static> =
ObligationCauseData { span: DUMMY_SP, body_id: hir::CRATE_HIR_ID, code: MiscObligation };

// Correctly format `ObligationCause::dummy`.
impl<'tcx> fmt::Debug for ObligationCause<'tcx> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
ObligationCauseData::fmt(self, f)
}
}

impl Deref for ObligationCause<'tcx> {
type Target = ObligationCauseData<'tcx>;

#[inline(always)]
fn deref(&self) -> &Self::Target {
self.data.as_deref().unwrap_or(&DUMMY_OBLIGATION_CAUSE_DATA)
}
}

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ObligationCauseData<'tcx> {
pub struct ObligationCause<'tcx> {
pub span: Span,

/// The ID of the fn body that triggered this obligation. This is
Expand All @@ -123,17 +98,38 @@ pub struct ObligationCauseData<'tcx> {
/// information.
pub body_id: hir::HirId,

pub code: ObligationCauseCode<'tcx>,
/// `None` for `DUMMY_OBLIGATION_CAUSE_CODE` (a very common case), `Some`
/// otherwise.
code: Option<Rc<ObligationCauseCode<'tcx>>>,
}

const DUMMY_OBLIGATION_CAUSE_CODE: ObligationCauseCode<'static> = MiscObligation;

// Correctly format `DUMMY_OBLIGATION_CAUSE_CODE`.
// njn: fix this
//impl<'tcx> fmt::Debug for ObligationCauseData<'tcx> {
// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// ObligationCauseCode::fmt(&self.code2, f)
// }
//}

impl<'tcx> ObligationCause<'tcx> {
#[inline(always)]
pub fn code(&self) -> &ObligationCauseCode<'tcx> {
self.code.as_deref().unwrap_or(&DUMMY_OBLIGATION_CAUSE_CODE)
}

#[inline]
pub fn new(
span: Span,
body_id: hir::HirId,
code: ObligationCauseCode<'tcx>,
) -> ObligationCause<'tcx> {
ObligationCause { data: Some(Rc::new(ObligationCauseData { span, body_id, code })) }
if code == DUMMY_OBLIGATION_CAUSE_CODE {
ObligationCause { span, body_id, code: None }
} else {
ObligationCause { span, body_id, code: Some(Rc::new(code)) }
}
}

pub fn misc(span: Span, body_id: hir::HirId) -> ObligationCause<'tcx> {
Expand All @@ -146,15 +142,15 @@ impl<'tcx> ObligationCause<'tcx> {

#[inline(always)]
pub fn dummy() -> ObligationCause<'tcx> {
ObligationCause { data: None }
ObligationCause { span: DUMMY_SP, body_id: hir::CRATE_HIR_ID, code: None }
}

pub fn make_mut(&mut self) -> &mut ObligationCauseData<'tcx> {
Rc::make_mut(self.data.get_or_insert_with(|| Rc::new(DUMMY_OBLIGATION_CAUSE_DATA)))
pub fn make_mut_code(&mut self) -> &mut ObligationCauseCode<'tcx> {
Rc::make_mut(self.code.get_or_insert_with(|| Rc::new(DUMMY_OBLIGATION_CAUSE_CODE)))
}

pub fn span(&self, tcx: TyCtxt<'tcx>) -> Span {
match self.code {
match *self.code() {
ObligationCauseCode::CompareImplMethodObligation { .. }
| ObligationCauseCode::MainFunctionType
| ObligationCauseCode::StartFunctionType => {
Expand Down
3 changes: 2 additions & 1 deletion src/librustc_middle/traits/structural_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,8 @@ impl<'a, 'tcx> Lift<'tcx> for traits::DerivedObligationCause<'a> {
impl<'a, 'tcx> Lift<'tcx> for traits::ObligationCause<'a> {
type Lifted = traits::ObligationCause<'tcx>;
fn lift_to_tcx(&self, tcx: TyCtxt<'tcx>) -> Option<Self::Lifted> {
tcx.lift(&self.code).map(|code| traits::ObligationCause::new(self.span, self.body_id, code))
tcx.lift(self.code())
.map(|code| traits::ObligationCause::new(self.span, self.body_id, code))
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/librustc_middle/ty/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ impl<T> Trait<T> for X {
proj_ty,
values,
body_owner_def_id,
&cause.code,
cause.code(),
);
}
(_, ty::Projection(proj_ty)) => {
Expand Down
25 changes: 14 additions & 11 deletions src/librustc_trait_selection/traits/error_reporting/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
self.note_obligation_cause_code(
&mut err,
&obligation.predicate,
&obligation.cause.code,
obligation.cause.code(),
&mut vec![],
);

Expand Down Expand Up @@ -242,7 +242,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
item_name,
impl_item_def_id,
trait_item_def_id,
} = obligation.cause.code
} = *obligation.cause.code()
{
self.report_extra_impl_obligation(
span,
Expand All @@ -263,7 +263,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
}
let trait_ref = trait_predicate.to_poly_trait_ref();
let (post_message, pre_message, type_def) = self
.get_parent_trait_ref(&obligation.cause.code)
.get_parent_trait_ref(obligation.cause.code())
.map(|(t, s)| {
(
format!(" in `{}`", t),
Expand Down Expand Up @@ -350,7 +350,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
}

let explanation =
if obligation.cause.code == ObligationCauseCode::MainFunctionType {
if *obligation.cause.code() == ObligationCauseCode::MainFunctionType {
"consider using `()`, or a `Result`".to_owned()
} else {
format!(
Expand Down Expand Up @@ -403,7 +403,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
}

self.suggest_dereferences(&obligation, &mut err, &trait_ref, points_at_arg);
self.suggest_borrow_on_unsized_slice(&obligation.cause.code, &mut err);
self.suggest_borrow_on_unsized_slice(obligation.cause.code(), &mut err);
self.suggest_fn_call(&obligation, &mut err, &trait_ref, points_at_arg);
self.suggest_remove_reference(&obligation, &mut err, &trait_ref);
self.suggest_semicolon_removal(&obligation, &mut err, span, &trait_ref);
Expand Down Expand Up @@ -1179,7 +1179,7 @@ impl<'a, 'tcx> InferCtxtPrivExt<'tcx> for InferCtxt<'a, 'tcx> {
normalized_ty, data.ty
);

let is_normalized_ty_expected = match &obligation.cause.code {
let is_normalized_ty_expected = match obligation.cause.code() {
ObligationCauseCode::ItemObligation(_)
| ObligationCauseCode::BindingObligation(_, _)
| ObligationCauseCode::ObjectCastObligation(_) => false,
Expand Down Expand Up @@ -1435,7 +1435,10 @@ impl<'a, 'tcx> InferCtxtPrivExt<'tcx> for InferCtxt<'a, 'tcx> {

debug!(
"maybe_report_ambiguity(predicate={:?}, obligation={:?} body_id={:?}, code={:?})",
predicate, obligation, body_id, obligation.cause.code,
predicate,
obligation,
body_id,
obligation.cause.code(),
);

// Ambiguity errors are often caused as fallout from earlier
Expand Down Expand Up @@ -1489,13 +1492,13 @@ impl<'a, 'tcx> InferCtxtPrivExt<'tcx> for InferCtxt<'a, 'tcx> {
}
let mut err = self.need_type_info_err(body_id, span, self_ty, ErrorCode::E0283);
err.note(&format!("cannot satisfy `{}`", predicate));
if let ObligationCauseCode::ItemObligation(def_id) = obligation.cause.code {
if let ObligationCauseCode::ItemObligation(def_id) = *obligation.cause.code() {
self.suggest_fully_qualified_path(&mut err, def_id, span, trait_ref.def_id());
} else if let (
Ok(ref snippet),
ObligationCauseCode::BindingObligation(ref def_id, _),
) =
(self.tcx.sess.source_map().span_to_snippet(span), &obligation.cause.code)
(self.tcx.sess.source_map().span_to_snippet(span), obligation.cause.code())
{
let generics = self.tcx.generics_of(*def_id);
if generics.params.iter().any(|p| p.name.as_str() != "Self")
Expand Down Expand Up @@ -1685,7 +1688,7 @@ impl<'a, 'tcx> InferCtxtPrivExt<'tcx> for InferCtxt<'a, 'tcx> {
self.note_obligation_cause_code(
err,
&obligation.predicate,
&obligation.cause.code,
obligation.cause.code(),
&mut vec![],
);
self.suggest_unsized_bound_if_applicable(err, obligation);
Expand All @@ -1698,7 +1701,7 @@ impl<'a, 'tcx> InferCtxtPrivExt<'tcx> for InferCtxt<'a, 'tcx> {
obligation: &PredicateObligation<'tcx>,
) {
let (pred, item_def_id, span) =
match (obligation.predicate.kind(), &obligation.cause.code.peel_derives()) {
match (obligation.predicate.kind(), &obligation.cause.code().peel_derives()) {
(
ty::PredicateKind::Trait(pred, _),
ObligationCauseCode::BindingObligation(item_def_id, span),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
self.describe_enclosure(obligation.cause.body_id).map(|s| s.to_owned()),
));

match obligation.cause.code {
match *obligation.cause.code() {
ObligationCauseCode::BuiltinDerivedObligation(..)
| ObligationCauseCode::ImplDerivedObligation(..)
| ObligationCauseCode::DerivedObligation(..) => {}
Expand All @@ -142,7 +142,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
}

if let ObligationCauseCode::ItemObligation(item)
| ObligationCauseCode::BindingObligation(item, _) = obligation.cause.code
| ObligationCauseCode::BindingObligation(item, _) = *obligation.cause.code()
{
// FIXME: maybe also have some way of handling methods
// from other traits? That would require name resolution,
Expand All @@ -155,7 +155,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
flags.push((sym::from_method, Some(method.to_string())));
}
}
if let Some((t, _)) = self.get_parent_trait_ref(&obligation.cause.code) {
if let Some((t, _)) = self.get_parent_trait_ref(obligation.cause.code()) {
flags.push((sym::parent_trait, Some(t)));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
let param_env = obligation.param_env;
let body_id = obligation.cause.body_id;
let span = obligation.cause.span;
let real_trait_ref = match &obligation.cause.code {
let real_trait_ref = match obligation.cause.code() {
ObligationCauseCode::ImplDerivedObligation(cause)
| ObligationCauseCode::DerivedObligation(cause)
| ObligationCauseCode::BuiltinDerivedObligation(cause) => &cause.parent_trait_ref,
Expand Down Expand Up @@ -691,7 +691,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
let param_env = obligation.param_env;
let trait_ref = trait_ref.skip_binder();

if let ObligationCauseCode::ImplDerivedObligation(obligation) = &obligation.cause.code {
if let ObligationCauseCode::ImplDerivedObligation(obligation) = obligation.cause.code() {
// Try to apply the original trait binding obligation by borrowing.
let self_ty = trait_ref.self_ty();
let found = self_ty.to_string();
Expand Down Expand Up @@ -951,7 +951,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
obligation: &PredicateObligation<'tcx>,
trait_ref: &ty::Binder<ty::TraitRef<'tcx>>,
) -> bool {
match obligation.cause.code.peel_derives() {
match obligation.cause.code().peel_derives() {
// Only suggest `impl Trait` if the return type is unsized because it is `dyn Trait`.
ObligationCauseCode::SizedReturnType => {}
_ => return false,
Expand Down Expand Up @@ -1145,7 +1145,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
err: &mut DiagnosticBuilder<'_>,
obligation: &PredicateObligation<'tcx>,
) {
match obligation.cause.code.peel_derives() {
match obligation.cause.code().peel_derives() {
ObligationCauseCode::SizedReturnType => {}
_ => return,
}
Expand Down Expand Up @@ -1339,7 +1339,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
};
let mut generator = None;
let mut outer_generator = None;
let mut next_code = Some(&obligation.cause.code);
let mut next_code = Some(obligation.cause.code());
while let Some(code) = next_code {
debug!("maybe_note_obligation_cause_for_async_await: code={:?}", code);
match code {
Expand Down
2 changes: 1 addition & 1 deletion src/librustc_trait_selection/traits/fulfill.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ pub struct PendingPredicateObligation<'tcx> {

// `PendingPredicateObligation` is used a lot. Make sure it doesn't unintentionally get bigger.
#[cfg(target_arch = "x86_64")]
static_assert_size!(PendingPredicateObligation<'_>, 72);
static_assert_size!(PendingPredicateObligation<'_>, 88);

impl<'a, 'tcx> FulfillmentContext<'tcx> {
/// Creates a new fulfillment context.
Expand Down
2 changes: 1 addition & 1 deletion src/librustc_trait_selection/traits/select/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2124,7 +2124,7 @@ impl<'tcx> TraitObligationExt<'tcx> for TraitObligation<'tcx> {
// by using -Z verbose or just a CLI argument.
let derived_cause = DerivedObligationCause {
parent_trait_ref: obligation.predicate.to_poly_trait_ref(),
parent_code: Rc::new(obligation.cause.code.clone()),
parent_code: Rc::new(obligation.cause.code().clone()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
parent_code: Rc::new(obligation.cause.code().clone()),
parent_code: obligation.cause.code(),

And also change parent_code to Option<Rc<ObligationCauseCode>>

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in general, all occurences of parent_code currently reallocate

};
let derived_code = variant(derived_cause);
ObligationCause::new(obligation.cause.span, obligation.cause.body_id, derived_code)
Expand Down
Loading