Skip to content

Support return-type bounds on associated methods from supertraits #111161

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion compiler/rustc_hir_analysis/messages.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,11 @@ hir_analysis_return_type_notation_equality_bound =
return type notation is not allowed to use type equality

hir_analysis_return_type_notation_missing_method =
cannot find associated function `{$assoc_name}` in trait `{$trait_name}`
cannot find associated function `{$assoc_name}` for `{$ty_name}`

hir_analysis_return_type_notation_conflicting_bound =
ambiguous associated function `{$assoc_name}` for `{$ty_name}`
.note = `{$assoc_name}` is declared in two supertraits: `{$first_bound}` and `{$second_bound}`

hir_analysis_placeholder_not_allowed_item_signatures = the placeholder `_` is not allowed within types on item signatures for {$kind}
.label = not allowed in type signatures
Expand Down
57 changes: 49 additions & 8 deletions compiler/rustc_hir_analysis/src/astconv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1040,7 +1040,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {

/// Convert the bounds in `ast_bounds` that refer to traits which define an associated type
/// named `assoc_name` into ty::Bounds. Ignore the rest.
pub(crate) fn compute_bounds_that_match_assoc_type(
pub(crate) fn compute_bounds_that_match_assoc_item(
&self,
param_ty: Ty<'tcx>,
ast_bounds: &[hir::GenericBound<'_>],
Expand All @@ -1051,7 +1051,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
for ast_bound in ast_bounds {
if let Some(trait_ref) = ast_bound.trait_ref()
&& let Some(trait_did) = trait_ref.trait_def_id()
&& self.tcx().trait_may_define_assoc_type(trait_did, assoc_name)
&& self.tcx().trait_may_define_assoc_item(trait_did, assoc_name)
{
result.push(ast_bound.clone());
}
Expand Down Expand Up @@ -1118,11 +1118,12 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
) {
trait_ref
} else {
return Err(tcx.sess.emit_err(crate::errors::ReturnTypeNotationMissingMethod {
span: binding.span,
trait_name: tcx.item_name(trait_ref.def_id()),
assoc_name: binding.item_name.name,
}));
self.one_bound_for_assoc_method(
traits::supertraits(tcx, trait_ref),
trait_ref.print_only_trait_path(),
binding.item_name,
path_span,
)?
}
} else if self.trait_defines_associated_item_named(
trait_ref.def_id(),
Expand Down Expand Up @@ -1922,7 +1923,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
let param_name = tcx.hir().ty_param_name(ty_param_def_id);
self.one_bound_for_assoc_type(
|| {
traits::transitive_bounds_that_define_assoc_type(
traits::transitive_bounds_that_define_assoc_item(
tcx,
predicates.iter().filter_map(|(p, _)| {
Some(p.to_opt_poly_trait_pred()?.map_bound(|t| t.trait_ref))
Expand Down Expand Up @@ -2057,6 +2058,46 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
Ok(bound)
}

#[instrument(level = "debug", skip(self, all_candidates, ty_name), ret)]
fn one_bound_for_assoc_method(
&self,
all_candidates: impl Iterator<Item = ty::PolyTraitRef<'tcx>>,
ty_name: impl Display,
assoc_name: Ident,
span: Span,
) -> Result<ty::PolyTraitRef<'tcx>, ErrorGuaranteed> {
let mut matching_candidates = all_candidates.filter(|r| {
self.trait_defines_associated_item_named(r.def_id(), ty::AssocKind::Fn, assoc_name)
});

let candidate = match matching_candidates.next() {
Some(candidate) => candidate,
None => {
return Err(self.tcx().sess.emit_err(
crate::errors::ReturnTypeNotationMissingMethod {
span,
ty_name: ty_name.to_string(),
assoc_name: assoc_name.name,
},
));
}
};

if let Some(conflicting_candidate) = matching_candidates.next() {
return Err(self.tcx().sess.emit_err(
crate::errors::ReturnTypeNotationConflictingBound {
span,
ty_name: ty_name.to_string(),
assoc_name: assoc_name.name,
first_bound: candidate.print_only_trait_path(),
second_bound: conflicting_candidate.print_only_trait_path(),
},
));
}

Ok(candidate)
}

// Create a type from a path to an associated type or to an enum variant.
// For a path `A::B::C::D`, `qself_ty` and `qself_def` are the type and def for `A::B::C`
// and item_segment is the path segment for `D`. We return a type and a def for
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_hir_analysis/src/collect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ pub fn provide(providers: &mut Providers) {
explicit_predicates_of: predicates_of::explicit_predicates_of,
super_predicates_of: predicates_of::super_predicates_of,
implied_predicates_of: predicates_of::implied_predicates_of,
super_predicates_that_define_assoc_type:
predicates_of::super_predicates_that_define_assoc_type,
super_predicates_that_define_assoc_item:
predicates_of::super_predicates_that_define_assoc_item,
trait_explicit_predicates_and_bounds: predicates_of::trait_explicit_predicates_and_bounds,
type_param_predicates: predicates_of::type_param_predicates,
trait_def,
Expand Down
6 changes: 3 additions & 3 deletions compiler/rustc_hir_analysis/src/collect/predicates_of.rs
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ pub(super) fn super_predicates_of(
implied_predicates_with_filter(tcx, trait_def_id.to_def_id(), PredicateFilter::SelfOnly)
}

pub(super) fn super_predicates_that_define_assoc_type(
pub(super) fn super_predicates_that_define_assoc_item(
tcx: TyCtxt<'_>,
(trait_def_id, assoc_name): (DefId, Ident),
) -> ty::GenericPredicates<'_> {
Expand Down Expand Up @@ -640,7 +640,7 @@ pub(super) fn implied_predicates_with_filter(
),
PredicateFilter::SelfThatDefines(assoc_name) => (
// Convert the bounds that follow the colon (or equal) that reference the associated name
icx.astconv().compute_bounds_that_match_assoc_type(self_param_ty, bounds, assoc_name),
icx.astconv().compute_bounds_that_match_assoc_item(self_param_ty, bounds, assoc_name),
// Include where clause bounds for `Self` that reference the associated name
icx.type_parameter_bounds_in_generics(
generics,
Expand Down Expand Up @@ -819,7 +819,7 @@ impl<'tcx> ItemCtxt<'tcx> {
hir::GenericBound::Trait(poly_trait_ref, _) => {
let trait_ref = &poly_trait_ref.trait_ref;
if let Some(trait_did) = trait_ref.trait_def_id() {
self.tcx.trait_may_define_assoc_type(trait_did, assoc_name)
self.tcx.trait_may_define_assoc_item(trait_did, assoc_name)
} else {
false
}
Expand Down
56 changes: 33 additions & 23 deletions compiler/rustc_hir_analysis/src/collect/resolve_bound_vars.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1652,27 +1652,28 @@ impl<'a, 'tcx> BoundVarContext<'a, 'tcx> {
if binding.gen_args.parenthesized == hir::GenericArgsParentheses::ReturnTypeNotation {
let bound_vars = if let Some(type_def_id) = type_def_id
&& self.tcx.def_kind(type_def_id) == DefKind::Trait
// FIXME(return_type_notation): We could bound supertrait methods.
&& let Some(assoc_fn) = self
.tcx
.associated_items(type_def_id)
.find_by_name_and_kind(self.tcx, binding.ident, ty::AssocKind::Fn, type_def_id)
&& let Some((mut bound_vars, assoc_fn)) =
BoundVarContext::supertrait_hrtb_vars(
self.tcx,
type_def_id,
binding.ident,
ty::AssocKind::Fn,
)
{
self.tcx
.generics_of(assoc_fn.def_id)
.params
.iter()
.map(|param| match param.kind {
bound_vars.extend(self.tcx.generics_of(assoc_fn.def_id).params.iter().map(
|param| match param.kind {
ty::GenericParamDefKind::Lifetime => ty::BoundVariableKind::Region(
ty::BoundRegionKind::BrNamed(param.def_id, param.name),
),
ty::GenericParamDefKind::Type { .. } => ty::BoundVariableKind::Ty(
ty::BoundTyKind::Param(param.def_id, param.name),
),
ty::GenericParamDefKind::Const { .. } => ty::BoundVariableKind::Const,
})
.chain(self.tcx.fn_sig(assoc_fn.def_id).subst_identity().bound_vars())
.collect()
},
));
bound_vars
.extend(self.tcx.fn_sig(assoc_fn.def_id).subst_identity().bound_vars());
bound_vars
} else {
self.tcx.sess.delay_span_bug(
binding.ident.span,
Expand All @@ -1689,8 +1690,13 @@ impl<'a, 'tcx> BoundVarContext<'a, 'tcx> {
});
});
} else if let Some(type_def_id) = type_def_id {
let bound_vars =
BoundVarContext::supertrait_hrtb_vars(self.tcx, type_def_id, binding.ident);
let bound_vars = BoundVarContext::supertrait_hrtb_vars(
self.tcx,
type_def_id,
binding.ident,
ty::AssocKind::Type,
)
.map(|(bound_vars, _)| bound_vars);
self.with(scope, |this| {
let scope = Scope::Supertrait {
bound_vars: bound_vars.unwrap_or_default(),
Expand Down Expand Up @@ -1720,11 +1726,15 @@ impl<'a, 'tcx> BoundVarContext<'a, 'tcx> {
tcx: TyCtxt<'tcx>,
def_id: DefId,
assoc_name: Ident,
) -> Option<Vec<ty::BoundVariableKind>> {
let trait_defines_associated_type_named = |trait_def_id: DefId| {
tcx.associated_items(trait_def_id)
.find_by_name_and_kind(tcx, assoc_name, ty::AssocKind::Type, trait_def_id)
.is_some()
assoc_kind: ty::AssocKind,
) -> Option<(Vec<ty::BoundVariableKind>, &'tcx ty::AssocItem)> {
let trait_defines_associated_item_named = |trait_def_id: DefId| {
tcx.associated_items(trait_def_id).find_by_name_and_kind(
tcx,
assoc_name,
assoc_kind,
trait_def_id,
)
};

use smallvec::{smallvec, SmallVec};
Expand All @@ -1742,10 +1752,10 @@ impl<'a, 'tcx> BoundVarContext<'a, 'tcx> {
_ => break None,
}

if trait_defines_associated_type_named(def_id) {
break Some(bound_vars.into_iter().collect());
if let Some(assoc_item) = trait_defines_associated_item_named(def_id) {
break Some((bound_vars.into_iter().collect(), assoc_item));
}
let predicates = tcx.super_predicates_that_define_assoc_type((def_id, assoc_name));
let predicates = tcx.super_predicates_that_define_assoc_item((def_id, assoc_name));
let obligations = predicates.predicates.iter().filter_map(|&(pred, _)| {
let bound_predicate = pred.kind();
match bound_predicate.skip_binder() {
Expand Down
16 changes: 14 additions & 2 deletions compiler/rustc_hir_analysis/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use rustc_errors::{
MultiSpan,
};
use rustc_macros::{Diagnostic, Subdiagnostic};
use rustc_middle::ty::Ty;
use rustc_middle::ty::{self, print::TraitRefPrintOnlyTraitPath, Ty};
use rustc_span::{symbol::Ident, Span, Symbol};

#[derive(Diagnostic)]
Expand Down Expand Up @@ -512,10 +512,22 @@ pub(crate) struct ReturnTypeNotationEqualityBound {
pub(crate) struct ReturnTypeNotationMissingMethod {
#[primary_span]
pub span: Span,
pub trait_name: Symbol,
pub ty_name: String,
pub assoc_name: Symbol,
}

#[derive(Diagnostic)]
#[diag(hir_analysis_return_type_notation_conflicting_bound)]
#[note]
pub(crate) struct ReturnTypeNotationConflictingBound<'tcx> {
#[primary_span]
pub span: Span,
pub ty_name: String,
pub assoc_name: Symbol,
pub first_bound: ty::Binder<'tcx, TraitRefPrintOnlyTraitPath<'tcx>>,
pub second_bound: ty::Binder<'tcx, TraitRefPrintOnlyTraitPath<'tcx>>,
}

#[derive(Diagnostic)]
#[diag(hir_analysis_placeholder_not_allowed_item_signatures, code = "E0121")]
pub(crate) struct PlaceholderNotAllowedItemSignatures {
Expand Down
8 changes: 4 additions & 4 deletions compiler/rustc_infer/src/traits/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,11 +376,11 @@ pub fn transitive_bounds<'tcx>(
}

/// A specialized variant of `elaborate` that only elaborates trait references that may
/// define the given associated type `assoc_name`. It uses the
/// `super_predicates_that_define_assoc_type` query to avoid enumerating super-predicates that
/// define the given associated item with the name `assoc_name`. It uses the
/// `super_predicates_that_define_assoc_item` query to avoid enumerating super-predicates that
/// aren't related to `assoc_item`. This is used when resolving types like `Self::Item` or
/// `T::Item` and helps to avoid cycle errors (see e.g. #35237).
pub fn transitive_bounds_that_define_assoc_type<'tcx>(
pub fn transitive_bounds_that_define_assoc_item<'tcx>(
tcx: TyCtxt<'tcx>,
bounds: impl Iterator<Item = ty::PolyTraitRef<'tcx>>,
assoc_name: Ident,
Expand All @@ -393,7 +393,7 @@ pub fn transitive_bounds_that_define_assoc_type<'tcx>(
let anon_trait_ref = tcx.anonymize_bound_vars(trait_ref);
if visited.insert(anon_trait_ref) {
let super_predicates =
tcx.super_predicates_that_define_assoc_type((trait_ref.def_id(), assoc_name));
tcx.super_predicates_that_define_assoc_item((trait_ref.def_id(), assoc_name));
for (super_predicate, _) in super_predicates.predicates {
let subst_predicate = super_predicate.subst_supertrait(tcx, &trait_ref);
if let Some(binder) = subst_predicate.to_opt_poly_trait_pred() {
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ rustc_queries! {
/// returns the full set of predicates. If `Some<Ident>`, then the query returns only the
/// subset of super-predicates that reference traits that define the given associated type.
/// This is used to avoid cycles in resolving types like `T::Item`.
query super_predicates_that_define_assoc_type(key: (DefId, rustc_span::symbol::Ident)) -> ty::GenericPredicates<'tcx> {
query super_predicates_that_define_assoc_item(key: (DefId, rustc_span::symbol::Ident)) -> ty::GenericPredicates<'tcx> {
desc { |tcx| "computing the super traits of `{}` with associated type name `{}`",
tcx.def_path_str(key.0),
key.1
Expand Down
6 changes: 3 additions & 3 deletions compiler/rustc_middle/src/ty/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1567,11 +1567,11 @@ impl<'tcx> TyCtxt<'tcx> {

/// Given the def_id of a Trait `trait_def_id` and the name of an associated item `assoc_name`
/// returns true if the `trait_def_id` defines an associated item of name `assoc_name`.
pub fn trait_may_define_assoc_type(self, trait_def_id: DefId, assoc_name: Ident) -> bool {
pub fn trait_may_define_assoc_item(self, trait_def_id: DefId, assoc_name: Ident) -> bool {
self.super_traits_of(trait_def_id).any(|trait_did| {
self.associated_items(trait_did)
.find_by_name_and_kind(self, assoc_name, ty::AssocKind::Type, trait_did)
.is_some()
.filter_by_name_unhygienic(assoc_name.name)
.any(|item| self.hygienic_eq(assoc_name, item.ident(self), trait_did))
})
}

Expand Down
6 changes: 6 additions & 0 deletions compiler/rustc_middle/src/ty/print/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2633,6 +2633,12 @@ macro_rules! define_print_and_forward_display {
#[derive(Copy, Clone, TypeFoldable, TypeVisitable, Lift)]
pub struct TraitRefPrintOnlyTraitPath<'tcx>(ty::TraitRef<'tcx>);

impl<'tcx> rustc_errors::IntoDiagnosticArg for TraitRefPrintOnlyTraitPath<'tcx> {
fn into_diagnostic_arg(self) -> rustc_errors::DiagnosticArgValue<'static> {
self.to_string().into_diagnostic_arg()
}
}

impl<'tcx> fmt::Debug for TraitRefPrintOnlyTraitPath<'tcx> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt(self, f)
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_trait_selection/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ pub use self::util::elaborate;
pub use self::util::{expand_trait_aliases, TraitAliasExpander};
pub use self::util::{get_vtable_index_of_object_method, impl_item_is_final, upcast_choices};
pub use self::util::{
supertrait_def_ids, supertraits, transitive_bounds, transitive_bounds_that_define_assoc_type,
supertrait_def_ids, supertraits, transitive_bounds, transitive_bounds_that_define_assoc_item,
SupertraitDefIds,
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ trait Trait {
}

fn bar<T: Trait<methid(): Send>>() {}
//~^ ERROR cannot find associated function `methid` in trait `Trait`
//~^ ERROR cannot find associated function `methid` for `Trait`

fn main() {}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ LL | #![feature(return_type_notation, async_fn_in_trait)]
= note: see issue #109417 <https://github.com/rust-lang/rust/issues/109417> for more information
= note: `#[warn(incomplete_features)]` on by default

error: cannot find associated function `methid` in trait `Trait`
error: cannot find associated function `methid` for `Trait`
--> $DIR/missing.rs:10:17
|
LL | fn bar<T: Trait<methid(): Send>>() {}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// edition:2021

#![feature(async_fn_in_trait, return_type_notation)]
//~^ WARN the feature `return_type_notation` is incomplete

trait Super1<'a> {
async fn test();
}
impl Super1<'_> for () {
async fn test() {}
}

trait Super2 {
async fn test();
}
impl Super2 for () {
async fn test() {}
}

trait Foo: for<'a> Super1<'a> + Super2 {}
impl Foo for () {}

fn test<T>()
where
T: Foo<test(): Send>,
//~^ ERROR ambiguous associated function `test` for `Foo`
{
}

fn main() {
test::<()>();
}
Loading