Skip to content

Commit 45f441a

Browse files
committed
nll: correctly deal with bivariance
1 parent ca92d90 commit 45f441a

File tree

5 files changed

+82
-55
lines changed

5 files changed

+82
-55
lines changed

compiler/rustc_const_eval/src/interpret/eval_context.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use super::{
2323
MemPlaceMeta, Memory, MemoryKind, Operand, Place, PlaceTy, PointerArithmetic, Provenance,
2424
Scalar, StackPopJump,
2525
};
26-
use crate::transform::validate::equal_up_to_regions;
26+
use crate::transform::validate;
2727

2828
pub struct InterpCx<'mir, 'tcx, M: Machine<'mir, 'tcx>> {
2929
/// Stores the `Machine` instance.
@@ -354,8 +354,8 @@ pub(super) fn mir_assign_valid_types<'tcx>(
354354
// Type-changing assignments can happen when subtyping is used. While
355355
// all normal lifetimes are erased, higher-ranked types with their
356356
// late-bound lifetimes are still around and can lead to type
357-
// differences. So we compare ignoring lifetimes.
358-
if equal_up_to_regions(tcx, param_env, src.ty, dest.ty) {
357+
// differences.
358+
if validate::is_subtype(tcx, param_env, src.ty, dest.ty) {
359359
// Make sure the layout is equal, too -- just to be safe. Miri really
360360
// needs layout equality. For performance reason we skip this check when
361361
// the types are equal. Equal types *can* have different layouts when

compiler/rustc_const_eval/src/transform/validate.rs

Lines changed: 41 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
33
use rustc_data_structures::fx::FxHashSet;
44
use rustc_index::bit_set::BitSet;
5-
use rustc_infer::infer::TyCtxtInferExt;
5+
use rustc_infer::infer::{DefiningAnchor, TyCtxtInferExt};
6+
use rustc_infer::traits::ObligationCause;
67
use rustc_middle::mir::interpret::Scalar;
78
use rustc_middle::mir::visit::NonUseContext::VarDebugInfo;
89
use rustc_middle::mir::visit::{PlaceContext, Visitor};
@@ -12,12 +13,12 @@ use rustc_middle::mir::{
1213
ProjectionElem, RuntimePhase, Rvalue, SourceScope, Statement, StatementKind, Terminator,
1314
TerminatorKind, UnOp, START_BLOCK,
1415
};
15-
use rustc_middle::ty::fold::BottomUpFolder;
16-
use rustc_middle::ty::{self, InstanceDef, ParamEnv, Ty, TyCtxt, TypeFoldable, TypeVisitable};
16+
use rustc_middle::ty::{self, InstanceDef, ParamEnv, Ty, TyCtxt, TypeVisitable};
1717
use rustc_mir_dataflow::impls::MaybeStorageLive;
1818
use rustc_mir_dataflow::storage::always_storage_live_locals;
1919
use rustc_mir_dataflow::{Analysis, ResultsCursor};
2020
use rustc_target::abi::{Size, VariantIdx};
21+
use rustc_trait_selection::traits::ObligationCtxt;
2122

2223
#[derive(Copy, Clone, Debug)]
2324
enum EdgeKind {
@@ -70,13 +71,11 @@ impl<'tcx> MirPass<'tcx> for Validator {
7071
}
7172
}
7273

73-
/// Returns whether the two types are equal up to lifetimes.
74-
/// All lifetimes, including higher-ranked ones, get ignored for this comparison.
75-
/// (This is unlike the `erasing_regions` methods, which keep higher-ranked lifetimes for soundness reasons.)
74+
/// Returns whether the two types are equal up to subtyping.
7675
///
77-
/// The point of this function is to approximate "equal up to subtyping". However,
78-
/// the approximation is incorrect as variance is ignored.
79-
pub fn equal_up_to_regions<'tcx>(
76+
/// This is used in case we don't know the expected subtyping direction
77+
/// and still want to check whether anything is broken.
78+
pub fn is_equal_up_to_subtyping<'tcx>(
8079
tcx: TyCtxt<'tcx>,
8180
param_env: ParamEnv<'tcx>,
8281
src: Ty<'tcx>,
@@ -87,27 +86,40 @@ pub fn equal_up_to_regions<'tcx>(
8786
return true;
8887
}
8988

90-
// Normalize lifetimes away on both sides, then compare.
91-
let normalize = |ty: Ty<'tcx>| {
92-
tcx.try_normalize_erasing_regions(param_env, ty).unwrap_or(ty).fold_with(
93-
&mut BottomUpFolder {
94-
tcx,
95-
// FIXME: We erase all late-bound lifetimes, but this is not fully correct.
96-
// If you have a type like `<for<'a> fn(&'a u32) as SomeTrait>::Assoc`,
97-
// this is not necessarily equivalent to `<fn(&'static u32) as SomeTrait>::Assoc`,
98-
// since one may have an `impl SomeTrait for fn(&32)` and
99-
// `impl SomeTrait for fn(&'static u32)` at the same time which
100-
// specify distinct values for Assoc. (See also #56105)
101-
lt_op: |_| tcx.lifetimes.re_erased,
102-
// Leave consts and types unchanged.
103-
ct_op: |ct| ct,
104-
ty_op: |ty| ty,
105-
},
106-
)
107-
};
108-
tcx.infer_ctxt().build().can_eq(param_env, normalize(src), normalize(dest)).is_ok()
89+
// Check for subtyping in either direction.
90+
is_subtype(tcx, param_env, src, dest) || is_subtype(tcx, param_env, dest, src)
10991
}
11092

93+
pub fn is_subtype<'tcx>(
94+
tcx: TyCtxt<'tcx>,
95+
param_env: ParamEnv<'tcx>,
96+
src: Ty<'tcx>,
97+
dest: Ty<'tcx>,
98+
) -> bool {
99+
if src == dest {
100+
return true;
101+
}
102+
103+
let mut builder =
104+
tcx.infer_ctxt().ignoring_regions().with_opaque_type_inference(DefiningAnchor::Bubble);
105+
let infcx = builder.build();
106+
let ocx = ObligationCtxt::new(&infcx);
107+
let cause = ObligationCause::dummy();
108+
let src = ocx.normalize(cause.clone(), param_env, src);
109+
let dest = ocx.normalize(cause.clone(), param_env, dest);
110+
let Ok(infer_ok) = infcx.at(&cause, param_env).sub(src, dest) else {
111+
return false;
112+
};
113+
let () = ocx.register_infer_ok_obligations(infer_ok);
114+
let errors = ocx.select_all_or_error();
115+
// With `Reveal::All`, opaque types get normalized away, with `Reveal::UserFacing`
116+
// we would get unification errors because we're unable to look into opaque types,
117+
// even if they're constrained in our current function.
118+
//
119+
// It seems very unlikely that this hides any bugs.
120+
let _ = infcx.inner.borrow_mut().opaque_type_storage.take_opaque_types();
121+
errors.is_empty()
122+
}
111123
struct TypeChecker<'a, 'tcx> {
112124
when: &'a str,
113125
body: &'a Body<'tcx>,
@@ -183,22 +195,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
183195
return true;
184196
}
185197

186-
// Normalize projections and things like that.
187-
// Type-changing assignments can happen when subtyping is used. While
188-
// all normal lifetimes are erased, higher-ranked types with their
189-
// late-bound lifetimes are still around and can lead to type
190-
// differences. So we compare ignoring lifetimes.
191-
192-
// First, try with reveal_all. This might not work in some cases, as the predicates
193-
// can be cleared in reveal_all mode. We try the reveal first anyways as it is used
194-
// by some other passes like inlining as well.
195-
let param_env = self.param_env.with_reveal_all_normalized(self.tcx);
196-
if equal_up_to_regions(self.tcx, param_env, src, dest) {
197-
return true;
198-
}
199-
200-
// If this fails, we can try it without the reveal.
201-
equal_up_to_regions(self.tcx, self.param_env, src, dest)
198+
is_subtype(self.tcx, self.param_env, src, dest)
202199
}
203200
}
204201

compiler/rustc_infer/src/infer/nll_relate/mod.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -556,8 +556,12 @@ where
556556
self.ambient_variance_info = self.ambient_variance_info.xform(info);
557557

558558
debug!(?self.ambient_variance);
559-
560-
let r = self.relate(a, b)?;
559+
// In a bivariant context this always succeeds.
560+
let r = if self.ambient_variance == ty::Variance::Bivariant {
561+
a
562+
} else {
563+
self.relate(a, b)?
564+
};
561565

562566
self.ambient_variance = old_ambient_variance;
563567

compiler/rustc_mir_transform/src/inline.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
//! Inlining pass for MIR functions
22
use crate::deref_separator::deref_finder;
33
use rustc_attr::InlineAttr;
4-
use rustc_const_eval::transform::validate::equal_up_to_regions;
54
use rustc_index::bit_set::BitSet;
65
use rustc_index::vec::Idx;
76
use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs};
@@ -14,7 +13,8 @@ use rustc_span::{hygiene::ExpnKind, ExpnData, LocalExpnId, Span};
1413
use rustc_target::abi::VariantIdx;
1514
use rustc_target::spec::abi::Abi;
1615

17-
use super::simplify::{remove_dead_blocks, CfgSimplifier};
16+
use crate::simplify::{remove_dead_blocks, CfgSimplifier};
17+
use crate::validate;
1818
use crate::MirPass;
1919
use std::iter;
2020
use std::ops::{Range, RangeFrom};
@@ -180,7 +180,7 @@ impl<'tcx> Inliner<'tcx> {
180180
let TerminatorKind::Call { args, destination, .. } = &terminator.kind else { bug!() };
181181
let destination_ty = destination.ty(&caller_body.local_decls, self.tcx).ty;
182182
let output_type = callee_body.return_ty();
183-
if !equal_up_to_regions(self.tcx, self.param_env, output_type, destination_ty) {
183+
if !validate::is_subtype(self.tcx, self.param_env, output_type, destination_ty) {
184184
trace!(?output_type, ?destination_ty);
185185
return Err("failed to normalize return type");
186186
}
@@ -200,7 +200,7 @@ impl<'tcx> Inliner<'tcx> {
200200
arg_tuple_tys.iter().zip(callee_body.args_iter().skip(skipped_args))
201201
{
202202
let input_type = callee_body.local_decls[input].ty;
203-
if !equal_up_to_regions(self.tcx, self.param_env, arg_ty, input_type) {
203+
if !validate::is_subtype(self.tcx, self.param_env, input_type, arg_ty) {
204204
trace!(?arg_ty, ?input_type);
205205
return Err("failed to normalize tuple argument type");
206206
}
@@ -209,7 +209,7 @@ impl<'tcx> Inliner<'tcx> {
209209
for (arg, input) in args.iter().zip(callee_body.args_iter()) {
210210
let input_type = callee_body.local_decls[input].ty;
211211
let arg_ty = arg.ty(&caller_body.local_decls, self.tcx);
212-
if !equal_up_to_regions(self.tcx, self.param_env, arg_ty, input_type) {
212+
if !validate::is_subtype(self.tcx, self.param_env, input_type, arg_ty) {
213213
trace!(?arg_ty, ?input_type);
214214
return Err("failed to normalize argument type");
215215
}
@@ -847,7 +847,7 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
847847
let parent = Place { local, projection: self.tcx.intern_place_elems(proj_base) };
848848
let parent_ty = parent.ty(&self.callee_body.local_decls, self.tcx);
849849
let check_equal = |this: &mut Self, f_ty| {
850-
if !equal_up_to_regions(this.tcx, this.param_env, ty, f_ty) {
850+
if !validate::is_equal_up_to_subtyping(this.tcx, this.param_env, ty, f_ty) {
851851
trace!(?ty, ?f_ty);
852852
this.validation = Err("failed to normalize projection type");
853853
return;
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// check-pass
2+
// compile-flags: -Zvalidate-mir
3+
4+
// This test checks that bivariant parameters are handled correctly
5+
// in the mir.
6+
#![allow(coherence_leak_check)]
7+
trait Trait {
8+
type Assoc;
9+
}
10+
11+
struct Foo<T, U>(T)
12+
where
13+
T: Trait<Assoc = U>;
14+
15+
impl Trait for for<'a> fn(&'a ()) {
16+
type Assoc = u32;
17+
}
18+
impl Trait for fn(&'static ()) {
19+
type Assoc = String;
20+
}
21+
22+
fn foo(x: Foo<for<'a> fn(&'a ()), u32>) -> Foo<fn(&'static ()), String> {
23+
x
24+
}
25+
26+
fn main() {}

0 commit comments

Comments
 (0)