Skip to content

On Fn arg mismatch for a fn path, suggest a closure #117805

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 173 additions & 18 deletions compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use rustc_errors::{
ErrorGuaranteed, MultiSpan, Style, SuggestionStyle,
};
use rustc_hir as hir;
use rustc_hir::def::DefKind;
use rustc_hir::def::{DefKind, Res};
use rustc_hir::def_id::DefId;
use rustc_hir::intravisit::Visitor;
use rustc_hir::is_range_literal;
Expand All @@ -36,7 +36,7 @@ use rustc_middle::ty::{
TypeSuperFoldable, TypeVisitableExt, TypeckResults,
};
use rustc_span::def_id::LocalDefId;
use rustc_span::symbol::{sym, Ident, Symbol};
use rustc_span::symbol::{kw, sym, Ident, Symbol};
use rustc_span::{BytePos, DesugaringKind, ExpnKind, MacroKind, Span, DUMMY_SP};
use rustc_target::spec::abi;
use std::borrow::Cow;
Expand Down Expand Up @@ -222,6 +222,15 @@ pub trait TypeErrCtxtExt<'tcx> {
param_env: ty::ParamEnv<'tcx>,
) -> DiagnosticBuilder<'tcx, ErrorGuaranteed>;

fn note_conflicting_fn_args(
&self,
err: &mut Diagnostic,
cause: &ObligationCauseCode<'tcx>,
expected: Ty<'tcx>,
found: Ty<'tcx>,
param_env: ty::ParamEnv<'tcx>,
);

fn note_conflicting_closure_bounds(
&self,
cause: &ObligationCauseCode<'tcx>,
Expand Down Expand Up @@ -1034,7 +1043,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
let hir::ExprKind::Path(hir::QPath::Resolved(None, path)) = expr.kind else {
return;
};
let hir::def::Res::Local(hir_id) = path.res else {
let Res::Local(hir_id) = path.res else {
return;
};
let Some(hir::Node::Pat(pat)) = self.tcx.hir().find(hir_id) else {
Expand Down Expand Up @@ -1618,7 +1627,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
}
}
if let hir::ExprKind::Path(hir::QPath::Resolved(None, path)) = expr.kind
&& let hir::def::Res::Local(hir_id) = path.res
&& let Res::Local(hir_id) = path.res
&& let Some(hir::Node::Pat(binding)) = self.tcx.hir().find(hir_id)
&& let Some(hir::Node::Local(local)) = self.tcx.hir().find_parent(binding.hir_id)
&& let None = local.ty
Expand Down Expand Up @@ -2005,6 +2014,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
let signature_kind = format!("{argument_kind} signature");
err.note_expected_found(&signature_kind, expected_str, &signature_kind, found_str);

self.note_conflicting_fn_args(&mut err, cause, expected, found, param_env);
self.note_conflicting_closure_bounds(cause, &mut err);

if let Some(found_node) = found_node {
Expand All @@ -2014,6 +2024,151 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
err
}

fn note_conflicting_fn_args(
&self,
err: &mut Diagnostic,
cause: &ObligationCauseCode<'tcx>,
expected: Ty<'tcx>,
found: Ty<'tcx>,
param_env: ty::ParamEnv<'tcx>,
) {
let ObligationCauseCode::FunctionArgumentObligation { arg_hir_id, .. } = cause else {
return;
};
let ty::FnPtr(expected) = expected.kind() else {
return;
};
let ty::FnPtr(found) = found.kind() else {
return;
};
let Some(Node::Expr(arg)) = self.tcx.hir().find(*arg_hir_id) else {
return;
};
let hir::ExprKind::Path(path) = arg.kind else {
return;
};
let expected_inputs = self.tcx.instantiate_bound_regions_with_erased(*expected).inputs();
let found_inputs = self.tcx.instantiate_bound_regions_with_erased(*found).inputs();
let both_tys = expected_inputs.iter().copied().zip(found_inputs.iter().copied());

let arg_expr = |infcx: &InferCtxt<'tcx>, name, expected: Ty<'tcx>, found: Ty<'tcx>| {
let (expected_ty, expected_refs) = get_deref_type_and_refs(expected);
let (found_ty, found_refs) = get_deref_type_and_refs(found);

if infcx.can_eq(param_env, found_ty, expected_ty) {
if found_refs.len() == expected_refs.len()
&& found_refs.iter().eq(expected_refs.iter())
{
name
} else if found_refs.len() > expected_refs.len() {
let refs = &found_refs[..found_refs.len() - expected_refs.len()];
if found_refs[..expected_refs.len()].iter().eq(expected_refs.iter()) {
format!(
"{}{name}",
refs.iter()
.map(|mutbl| format!("&{}", mutbl.prefix_str()))
.collect::<Vec<_>>()
.join(""),
)
} else {
// The refs have different mutability.
format!(
"{}*{name}",
refs.iter()
.map(|mutbl| format!("&{}", mutbl.prefix_str()))
.collect::<Vec<_>>()
.join(""),
)
}
} else if expected_refs.len() > found_refs.len() {
format!(
"{}{name}",
(0..(expected_refs.len() - found_refs.len()))
.map(|_| "*")
.collect::<Vec<_>>()
.join(""),
)
} else {
format!(
"{}{name}",
found_refs
.iter()
.map(|mutbl| format!("&{}", mutbl.prefix_str()))
.chain(found_refs.iter().map(|_| "*".to_string()))
.collect::<Vec<_>>()
.join(""),
)
}
} else {
format!("/* {found} */")
}
};
let args_have_same_underlying_type = both_tys.clone().all(|(expected, found)| {
let (expected_ty, _) = get_deref_type_and_refs(expected);
let (found_ty, _) = get_deref_type_and_refs(found);
self.can_eq(param_env, found_ty, expected_ty)
});
let (closure_names, call_names): (Vec<_>, Vec<_>) = if args_have_same_underlying_type
&& !expected_inputs.is_empty()
&& expected_inputs.len() == found_inputs.len()
&& let Some(typeck) = &self.typeck_results
&& let Res::Def(_, fn_def_id) = typeck.qpath_res(&path, *arg_hir_id)
{
let closure: Vec<_> = self
.tcx
.fn_arg_names(fn_def_id)
.iter()
.enumerate()
.map(|(i, ident)| {
if ident.name.is_empty() || ident.name == kw::SelfLower {
format!("arg{i}")
} else {
format!("{ident}")
}
})
.collect();
let args = closure
.iter()
.zip(both_tys)
.map(|(name, (expected, found))| {
arg_expr(self.infcx, name.to_owned(), expected, found)
})
.collect();
(closure, args)
} else {
let closure_args = expected_inputs
.iter()
.enumerate()
.map(|(i, _)| format!("arg{i}"))
.collect::<Vec<_>>();
let call_args = both_tys
.enumerate()
.map(|(i, (expected, found))| {
arg_expr(self.infcx, format!("arg{i}"), expected, found)
})
.collect::<Vec<_>>();
(closure_args, call_args)
};
let closure_names: Vec<_> = closure_names
.into_iter()
.zip(expected_inputs.iter())
.map(|(name, ty)| {
format!(
"{name}{}",
if ty.has_infer_types() { String::new() } else { format!(": {ty}") }
)
})
.collect();
err.multipart_suggestion(
format!("consider wrapping the function in a closure"),
vec![
(arg.span.shrink_to_lo(), format!("|{}| ", closure_names.join(", "))),
(arg.span.shrink_to_hi(), format!("({})", call_names.join(", "))),
],
Applicability::MaybeIncorrect,
);
}

// Add a note if there are two `Fn`-family bounds that have conflicting argument
// requirements, which will always cause a closure to have a type error.
fn note_conflicting_closure_bounds(
Expand Down Expand Up @@ -3634,7 +3789,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
}
}
if let hir::ExprKind::Path(hir::QPath::Resolved(None, path)) = expr.kind
&& let hir::Path { res: hir::def::Res::Local(hir_id), .. } = path
&& let hir::Path { res: Res::Local(hir_id), .. } = path
&& let Some(hir::Node::Pat(binding)) = self.tcx.hir().find(*hir_id)
&& let parent_hir_id = self.tcx.hir().parent_id(binding.hir_id)
&& let Some(hir::Node::Local(local)) = self.tcx.hir().find(parent_hir_id)
Expand Down Expand Up @@ -3894,7 +4049,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
);

if let hir::ExprKind::Path(hir::QPath::Resolved(None, path)) = expr.kind
&& let hir::Path { res: hir::def::Res::Local(hir_id), .. } = path
&& let hir::Path { res: Res::Local(hir_id), .. } = path
&& let Some(hir::Node::Pat(binding)) = self.tcx.hir().find(*hir_id)
&& let Some(parent) = self.tcx.hir().find_parent(binding.hir_id)
{
Expand Down Expand Up @@ -4349,17 +4504,6 @@ fn hint_missing_borrow<'tcx>(

let args = fn_decl.inputs.iter();

fn get_deref_type_and_refs(mut ty: Ty<'_>) -> (Ty<'_>, Vec<hir::Mutability>) {
let mut refs = vec![];

while let ty::Ref(_, new_ty, mutbl) = ty.kind() {
ty = *new_ty;
refs.push(*mutbl);
}

(ty, refs)
}

let mut to_borrow = Vec::new();
let mut remove_borrow = Vec::new();

Expand Down Expand Up @@ -4519,7 +4663,7 @@ impl<'a, 'hir> hir::intravisit::Visitor<'hir> for ReplaceImplTraitVisitor<'a> {
fn visit_ty(&mut self, t: &'hir hir::Ty<'hir>) {
if let hir::TyKind::Path(hir::QPath::Resolved(
None,
hir::Path { res: hir::def::Res::Def(_, segment_did), .. },
hir::Path { res: Res::Def(_, segment_did), .. },
)) = t.kind
{
if self.param_did == *segment_did {
Expand Down Expand Up @@ -4652,3 +4796,14 @@ pub fn suggest_desugaring_async_fn_to_impl_future_in_trait<'tcx>(

Some(sugg)
}

fn get_deref_type_and_refs(mut ty: Ty<'_>) -> (Ty<'_>, Vec<hir::Mutability>) {
let mut refs = vec![];

while let ty::Ref(_, new_ty, mutbl) = ty.kind() {
ty = *new_ty;
refs.push(*mutbl);
}

(ty, refs)
}
4 changes: 4 additions & 0 deletions tests/ui/generic-associated-types/bugs/issue-88382.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ note: required by a bound in `do_something`
|
LL | fn do_something<I: Iterable>(i: I, mut f: impl for<'a> Fn(&mut I::Iterator<'a>)) {
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `do_something`
help: consider wrapping the function in a closure
|
LL | do_something(SomeImplementation(), |arg0: &mut std::iter::Empty<usize>| test(/* &mut <_ as Iterable>::Iterator<'_> */));
| ++++++++++++++++++++++++++++++++++++ ++++++++++++++++++++++++++++++++++++++++++

error: aborting due to 1 previous error

Expand Down
4 changes: 4 additions & 0 deletions tests/ui/intrinsics/const-eval-select-bad.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ LL | const_eval_select((true,), foo, baz);
found function signature `fn(i32) -> _`
note: required by a bound in `const_eval_select`
--> $SRC_DIR/core/src/intrinsics.rs:LL:COL
help: consider wrapping the function in a closure
|
LL | const_eval_select((true,), |arg0: bool| foo(/* i32 */), baz);
| ++++++++++++ +++++++++++

error: this argument must be a `const fn`
--> $DIR/const-eval-select-bad.rs:42:29
Expand Down
8 changes: 8 additions & 0 deletions tests/ui/mismatched_types/E0631.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ note: required by a bound in `foo`
|
LL | fn foo<F: Fn(usize)>(_: F) {}
| ^^^^^^^^^ required by this bound in `foo`
help: consider wrapping the function in a closure
|
LL | foo(|arg0: usize| f(/* u64 */));
| +++++++++++++ +++++++++++

error[E0631]: type mismatch in function arguments
--> $DIR/E0631.rs:10:9
Expand All @@ -67,6 +71,10 @@ note: required by a bound in `bar`
|
LL | fn bar<F: Fn<(usize,)>>(_: F) {}
| ^^^^^^^^^^^^ required by this bound in `bar`
help: consider wrapping the function in a closure
|
LL | bar(|arg0: usize| f(/* u64 */));
| +++++++++++++ +++++++++++

error: aborting due to 4 previous errors

Expand Down
4 changes: 4 additions & 0 deletions tests/ui/mismatched_types/closure-ref-114180.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ LL | v.sort_by(compare);
found closure signature `fn((_,), (_,)) -> _`
note: required by a bound in `slice::<impl [T]>::sort_by`
--> $SRC_DIR/alloc/src/slice.rs:LL:COL
help: consider wrapping the function in a closure
|
LL | v.sort_by(|arg0, arg1| compare(*arg0, *arg1));
| ++++++++++++ ++++++++++++++
help: consider adjusting the signature so it borrows its arguments
|
LL | let compare = |&(a,), &(e,)| todo!();
Expand Down
8 changes: 8 additions & 0 deletions tests/ui/mismatched_types/fn-variance-1.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ note: required by a bound in `apply`
|
LL | fn apply<T, F>(t: T, f: F) where F: FnOnce(T) {
| ^^^^^^^^^ required by this bound in `apply`
help: consider wrapping the function in a closure
|
LL | apply(&3, |x| takes_mut(&mut *x));
| +++ +++++++++

error[E0631]: type mismatch in function arguments
--> $DIR/fn-variance-1.rs:15:19
Expand All @@ -35,6 +39,10 @@ note: required by a bound in `apply`
|
LL | fn apply<T, F>(t: T, f: F) where F: FnOnce(T) {
| ^^^^^^^^^ required by this bound in `apply`
help: consider wrapping the function in a closure
|
LL | apply(&mut 3, |x| takes_imm(&*x));
| +++ +++++

error: aborting due to 2 previous errors

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ LL | let _has_inference_vars: Option<i32> = Some(0).map(deref_int);
found function signature `for<'a> fn(&'a i32) -> _`
note: required by a bound in `Option::<T>::map`
--> $SRC_DIR/core/src/option.rs:LL:COL
help: consider wrapping the function in a closure
|
LL | let _has_inference_vars: Option<i32> = Some(0).map(|a| deref_int(&a));
| +++ ++++
help: consider adjusting the signature so it does not borrow its argument
|
LL - fn deref_int(a: &i32) -> i32 {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ LL | let _ = produces_string().and_then(takes_str_but_too_many_refs);
found function signature `for<'a, 'b> fn(&'a &'b str) -> _`
note: required by a bound in `Option::<T>::and_then`
--> $SRC_DIR/core/src/option.rs:LL:COL
help: consider wrapping the function in a closure
|
LL | let _ = produces_string().and_then(|arg0: String| takes_str_but_too_many_refs(/* &&str */));
| ++++++++++++++ +++++++++++++

error[E0277]: expected a `FnOnce(String)` closure, found `for<'a> extern "C" fn(&'a str) -> Option<()> {takes_str_but_wrong_abi}`
--> $DIR/suggest-option-asderef-unfixable.rs:26:40
Expand Down Expand Up @@ -68,6 +72,10 @@ LL | let _ = Some(TypeWithoutDeref).and_then(takes_str_but_too_many_refs);
found function signature `for<'a, 'b> fn(&'a &'b str) -> _`
note: required by a bound in `Option::<T>::and_then`
--> $SRC_DIR/core/src/option.rs:LL:COL
help: consider wrapping the function in a closure
|
LL | let _ = Some(TypeWithoutDeref).and_then(|arg0: TypeWithoutDeref| takes_str_but_too_many_refs(/* &&str */));
| ++++++++++++++++++++++++ +++++++++++++

error: aborting due to 5 previous errors

Expand Down
Loading