From fe5be5965e8d193f71e474e351110debe0f690d7 Mon Sep 17 00:00:00 2001 From: Guillaume Martres Date: Mon, 30 Mar 2020 12:39:37 +0200 Subject: [PATCH] Avoid inference getting stuck when the expected type contains a union/intersection When we type a method call, we infer constraints based on its expected type before typing its arguments. This way, we can type these arguments with a precise expected type. This works fine as long as the constraints we infer based on the expected type are _necessary_ constraints, but in general type inference can go further and infer _sufficient_ constraints, meaning that we might get stuck with a set of constraints which does not allow the method arguments to be typed at all. Since 8067b952875426d640968be865773f6ef3783f3c we work around the problem by simply not propagating any constraint when the expected type is a union, but this solution is incomplete: - It only handles unions at the top-level, but the same problem can happen with unions in any covariant position (method b of or-inf.scala) as well as intersections in contravariant positions (and-inf.scala, i8378.scala) - Even when a union appear at the top-level, there might be constraints we can propagate, for example if only one branch can possibly match (method c of or-inf.scala) Thankfully, we already have a solution that works for all these problems: `TypeComparer#either` is capable of inferring only necessary constraints. So far, this was only done when inferring GADT bounds to preserve soundness, this commit extends this to use the same logic when constraining a method based on its expected type (as determined by the ConstrainResult mode). Additionally, `ConstraintHandling#addConstraint` needs to also be taught to only keep necessary constraints under this mode. Fixes #8378 which I previously thought was unfixable :). --- .../tools/dotc/core/ConstraintHandling.scala | 26 ++++-- compiler/src/dotty/tools/dotc/core/Mode.scala | 3 + .../dotty/tools/dotc/core/TypeComparer.scala | 85 +++++++++++-------- .../dotty/tools/dotc/typer/ProtoTypes.scala | 11 +-- .../dotty/tools/dotc/CompilationTests.scala | 1 + .../interop-polytypes.scala | 0 tests/neg/i6565.scala | 4 +- tests/neg/union.scala | 2 +- tests/pos/and-inf.scala | 13 +++ tests/pos/i7829.scala | 27 ++++++ tests/pos/i8378.scala | 17 ++++ tests/pos/or-inf.scala | 14 +++ tests/pos/orinf.scala | 6 -- 13 files changed, 151 insertions(+), 58 deletions(-) rename tests/{explicit-nulls/neg => neg-custom-args}/interop-polytypes.scala (100%) create mode 100644 tests/pos/and-inf.scala create mode 100644 tests/pos/i7829.scala create mode 100644 tests/pos/i8378.scala create mode 100644 tests/pos/or-inf.scala delete mode 100644 tests/pos/orinf.scala diff --git a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala index 62c60fbf93c0..0f61fd2e25fe 100644 --- a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala +++ b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala @@ -484,9 +484,10 @@ trait ConstraintHandling[AbstractContext] { * recording an isLess relationship instead (even though this is not implied * by the bound). * - * Narrowing a constraint is better than widening it, because narrowing leads - * to incompleteness (which we face anyway, see for instance eitherIsSubType) - * but widening leads to unsoundness. + * Normally, narrowing a constraint is better than widening it, because + * narrowing leads to incompleteness (which we face anyway, see for + * instance `TypeComparer#either`) but widening leads to unsoundness, + * but note the special handling in `ConstrainResult` mode below. * * A test case that demonstrates the problem is i864.scala. * Turn Config.checkConstraintsSeparated on to get an accurate diagnostic @@ -544,10 +545,23 @@ trait ConstraintHandling[AbstractContext] { case bound: TypeParamRef if constraint contains bound => addParamBound(bound) case _ => + val savedConstraint = constraint val pbound = prune(bound) - pbound.exists - && kindCompatible(param, pbound) - && (if fromBelow then addLowerBound(param, pbound) else addUpperBound(param, pbound)) + val constraintsNarrowed = constraint ne savedConstraint + + val res = + pbound.exists + && kindCompatible(param, pbound) + && (if fromBelow then addLowerBound(param, pbound) else addUpperBound(param, pbound)) + // If we're in `ConstrainResult` mode, we don't want to commit to a + // set of constraints that would later prevent us from typechecking + // arguments, so if `pruneParams` had to narrow the constraints, we + // simply do not record any new constraint. + // Unlike in `TypeComparer#either`, the same reasoning does not apply + // to GADT mode because this code is never run on GADT constraints. + if ctx.mode.is(Mode.ConstrainResult) && constraintsNarrowed then + constraint = savedConstraint + res } finally addConstraintInvocations -= 1 } diff --git a/compiler/src/dotty/tools/dotc/core/Mode.scala b/compiler/src/dotty/tools/dotc/core/Mode.scala index bc49bd8ec2ed..f6a6c97c25e1 100644 --- a/compiler/src/dotty/tools/dotc/core/Mode.scala +++ b/compiler/src/dotty/tools/dotc/core/Mode.scala @@ -60,6 +60,9 @@ object Mode { */ val Printing: Mode = newMode(10, "Printing") + /** We are constraining a method based on its expected type. */ + val ConstrainResult: Mode = newMode(11, "ConstrainResult") + /** We are currently in a `viewExists` check. In that case, ambiguous * implicits checks are disabled and we succeed with the first implicit * found. diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index f06928a57ce8..62fae9637151 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -1364,14 +1364,26 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w /** Returns true iff the result of evaluating either `op1` or `op2` is true and approximates resulting constraints. * - * If we're _not_ in GADTFlexible mode, we try to keep the smaller of the two constraints. - * If we're _in_ GADTFlexible mode, we keep the smaller constraint if any, or no constraint at all. + * If we're inferring GADT bounds or constraining a method based on its + * expected type, we infer only the _necessary_ constraints, this means we + * keep the smaller constraint if any, or no constraint at all. This is + * necessary for GADT bounds inference to be sound. When constraining a + * method, this avoid committing of constraints that would later prevent us + * from typechecking method arguments, see or-inf.scala and and-inf.scala for + * examples. * + * Otherwise, we infer _sufficient_ constraints: we try to keep the smaller of + * the two constraints, but if never is smaller than the other, we just pick + * the first one. + * + * @see [[necessaryEither]] for the GADT / result type case * @see [[sufficientEither]] for the normal case - * @see [[necessaryEither]] for the GADTFlexible case */ protected def either(op1: => Boolean, op2: => Boolean): Boolean = - if (ctx.mode.is(Mode.GadtConstraintInference)) necessaryEither(op1, op2) else sufficientEither(op1, op2) + if ctx.mode.is(Mode.GadtConstraintInference) || ctx.mode.is(Mode.ConstrainResult) then + necessaryEither(op1, op2) + else + sufficientEither(op1, op2) /** Returns true iff the result of evaluating either `op1` or `op2` is true, * trying at the same time to keep the constraint as wide as possible. @@ -1438,8 +1450,8 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w * T1 & T2 <:< T3 * T1 <:< T2 | T3 * - * Unlike [[sufficientEither]], this method is used in GADTFlexible mode, when we are attempting to infer GADT - * constraints that necessarily follow from the subtyping relationship. For instance, if we have + * Unlike [[sufficientEither]], this method is used in GADTConstraintInference mode, when we are attempting + * to infer GADT constraints that necessarily follow from the subtyping relationship. For instance, if we have * * enum Expr[T] { * case IntExpr(i: Int) extends Expr[Int] @@ -1466,48 +1478,49 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w * * then the necessary constraint is { A = Int }, but correctly inferring that is, as far as we know, too expensive. * + * This method is also used in ConstrainResult mode + * to avoid inference getting stuck due to lack of backtracking, + * see or-inf.scala and and-inf.scala for examples. + * * Method name comes from the notion that we are keeping the constraint which is necessary to satisfy both * subtyping relationships. */ - private def necessaryEither(op1: => Boolean, op2: => Boolean): Boolean = { + private def necessaryEither(op1: => Boolean, op2: => Boolean): Boolean = val preConstraint = constraint - val preGadt = ctx.gadt.fresh - // if GADTflexible mode is on, we expect to always have a ProperGadtConstraint - val pre = preGadt.asInstanceOf[ProperGadtConstraint] - if (op1) { - val leftConstraint = constraint - val leftGadt = ctx.gadt.fresh + + def allSubsumes(leftGadt: GadtConstraint, rightGadt: GadtConstraint, left: Constraint, right: Constraint): Boolean = + subsumes(left, right, preConstraint) && preGadt.match + case preGadt: ProperGadtConstraint => + preGadt.subsumes(leftGadt, rightGadt, preGadt) + case _ => + true + + if op1 then + val op1Constraint = constraint + val op1Gadt = ctx.gadt.fresh constraint = preConstraint ctx.gadt.restore(preGadt) - if (op2) - if (pre.subsumes(leftGadt, ctx.gadt, preGadt) && subsumes(leftConstraint, constraint, preConstraint)) { - gadts.println(i"GADT CUT - prefer ${ctx.gadt} over $leftGadt") - constr.println(i"CUT - prefer $constraint over $leftConstraint") - true - } - else if (pre.subsumes(ctx.gadt, leftGadt, preGadt) && subsumes(constraint, leftConstraint, preConstraint)) { - gadts.println(i"GADT CUT - prefer $leftGadt over ${ctx.gadt}") - constr.println(i"CUT - prefer $leftConstraint over $constraint") - constraint = leftConstraint - ctx.gadt.restore(leftGadt) - true - } - else { + if op2 then + if allSubsumes(op1Gadt, ctx.gadt, op1Constraint, constraint) then + gadts.println(i"GADT CUT - prefer ${ctx.gadt} over $op1Gadt") + constr.println(i"CUT - prefer $constraint over $op1Constraint") + else if allSubsumes(ctx.gadt, op1Gadt, constraint, op1Constraint) then + gadts.println(i"GADT CUT - prefer $op1Gadt over ${ctx.gadt}") + constr.println(i"CUT - prefer $op1Constraint over $constraint") + constraint = op1Constraint + ctx.gadt.restore(op1Gadt) + else gadts.println(i"GADT CUT - no constraint is preferable, reverting to $preGadt") constr.println(i"CUT - no constraint is preferable, reverting to $preConstraint") constraint = preConstraint ctx.gadt.restore(preGadt) - true - } - else { - constraint = leftConstraint - ctx.gadt.restore(leftGadt) - true - } - } + else + constraint = op1Constraint + ctx.gadt.restore(op1Gadt) + true else op2 - } + end necessaryEither /** Does type `tp1` have a member with name `name` whose normalized type is a subtype of * the normalized type of the refinement `tp2`? diff --git a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala index e99a86088a8e..418b3538d34a 100644 --- a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala +++ b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala @@ -59,17 +59,14 @@ object ProtoTypes { else ctx.test(testCompat) } - private def disregardProto(pt: Type)(implicit ctx: Context): Boolean = pt.dealias match { - case _: OrType => true - // Don't constrain results with union types, since comparison with a union - // type on the right might commit too early into one side. - case pt => pt.isRef(defn.UnitClass) - } + private def disregardProto(pt: Type)(implicit ctx: Context): Boolean = + pt.dealias.isRef(defn.UnitClass) /** Check that the result type of the current method * fits the given expected result type. */ - def constrainResult(mt: Type, pt: Type)(implicit ctx: Context): Boolean = { + def constrainResult(mt: Type, pt: Type)(implicit parentCtx: Context): Boolean = { + given ctx as Context = parentCtx.addMode(Mode.ConstrainResult) val savedConstraint = ctx.typerState.constraint val res = pt.widenExpr match { case pt: FunProto => diff --git a/compiler/test/dotty/tools/dotc/CompilationTests.scala b/compiler/test/dotty/tools/dotc/CompilationTests.scala index e8f2df1a8802..8cccc15d3e85 100644 --- a/compiler/test/dotty/tools/dotc/CompilationTests.scala +++ b/compiler/test/dotty/tools/dotc/CompilationTests.scala @@ -137,6 +137,7 @@ class CompilationTests extends ParallelTesting { compileFile("tests/neg-custom-args/i3882.scala", allowDeepSubtypes), compileFile("tests/neg-custom-args/i4372.scala", allowDeepSubtypes), compileFile("tests/neg-custom-args/i1754.scala", allowDeepSubtypes), + compileFile("tests/neg-custom-args/interop-polytypes.scala", allowDeepSubtypes.and("-Yexplicit-nulls")), compileFile("tests/neg-custom-args/conditionalWarnings.scala", allowDeepSubtypes.and("-deprecation").and("-Xfatal-warnings")), compileFilesInDir("tests/neg-custom-args/isInstanceOf", allowDeepSubtypes and "-Xfatal-warnings"), compileFile("tests/neg-custom-args/i3627.scala", allowDeepSubtypes), diff --git a/tests/explicit-nulls/neg/interop-polytypes.scala b/tests/neg-custom-args/interop-polytypes.scala similarity index 100% rename from tests/explicit-nulls/neg/interop-polytypes.scala rename to tests/neg-custom-args/interop-polytypes.scala diff --git a/tests/neg/i6565.scala b/tests/neg/i6565.scala index a51eeb24c308..d5fab12842d3 100644 --- a/tests/neg/i6565.scala +++ b/tests/neg/i6565.scala @@ -9,9 +9,9 @@ def (o: Lifted[O]) flatMap [O, U] (f: O => Lifted[U]): Lifted[U] = ??? val error: Err = Err() lazy val ok: Lifted[String] = { // ok despite map returning a union - point("a").map(_ => if true then "foo" else error) // error + point("a").map(_ => if true then "foo" else error) // ok } lazy val bad: Lifted[String] = { // found Lifted[Object] point("a").flatMap(_ => point("b").map(_ => if true then "foo" else error)) // error -} \ No newline at end of file +} diff --git a/tests/neg/union.scala b/tests/neg/union.scala index c594e83d74bc..0a702ab70058 100644 --- a/tests/neg/union.scala +++ b/tests/neg/union.scala @@ -17,7 +17,7 @@ object O { val x: A = f(new A { }, new A) - val y1: A | B = f(new A { }, new B) // error + val y1: A | B = f(new A { }, new B) // ok val y2: A | B = f[A | B](new A { }, new B) // ok val z = if (???) new A{} else new B diff --git a/tests/pos/and-inf.scala b/tests/pos/and-inf.scala new file mode 100644 index 000000000000..3008014a00a9 --- /dev/null +++ b/tests/pos/and-inf.scala @@ -0,0 +1,13 @@ +class A +class B + +class Inv[T] +class Contra[-T] + +class Test { + def foo[T, S](x: T, y: S): Contra[Inv[T] & Inv[S]] = ??? + val a: A = new A + val b: B = new B + + val x: Contra[Inv[A] & Inv[B]] = foo(a, b) +} diff --git a/tests/pos/i7829.scala b/tests/pos/i7829.scala new file mode 100644 index 000000000000..2f3d71366b7c --- /dev/null +++ b/tests/pos/i7829.scala @@ -0,0 +1,27 @@ +class X +class Y + +object Test { + type Id[T] = T + + val a: 1 = identity(1) + val b: Id[1] = identity(1) + + val c: X | Y = identity(if (true) new X else new Y) + val d: Id[X | Y] = identity(if (true) new X else new Y) + + def impUnion: Unit = { + class Base + class A extends Base + class B extends Base + class Inv[T] + + implicit def invBase: Inv[Base] = new Inv[Base] + + def getInv[T](x: T)(implicit inv: Inv[T]): Int = 1 + + val a: Int = getInv(if (true) new A else new B) + // If we keep unions when doing the implicit search, this would give us: "no implicit argument of type Inv[X | Y]" + val b: Int | Any = getInv(if (true) new A else new B) + } +} diff --git a/tests/pos/i8378.scala b/tests/pos/i8378.scala new file mode 100644 index 000000000000..b69fec928c76 --- /dev/null +++ b/tests/pos/i8378.scala @@ -0,0 +1,17 @@ +trait Has[A] + +trait A +trait B +trait C + +trait ZLayer[-RIn, +E, +ROut] + +object ZLayer { + def fromServices[A0, A1, B](f: (A0, A1) => B): ZLayer[Has[A0] with Has[A1], Nothing, Has[B]] = + ??? +} + +val live: ZLayer[Has[A] & Has[B], Nothing, Has[C]] = + ZLayer.fromServices { (a: A, b: B) => + new C {} + } diff --git a/tests/pos/or-inf.scala b/tests/pos/or-inf.scala new file mode 100644 index 000000000000..e6022b888e14 --- /dev/null +++ b/tests/pos/or-inf.scala @@ -0,0 +1,14 @@ +object Test { + + def a(lis: Set[Int] | Set[String]) = {} + a(Set(1)) + a(Set("")) + + def b(lis: List[Set[Int] | Set[String]]) = {} + b(List(Set(1))) + b(List(Set(""))) + + def c(x: Set[Any] | Array[Any]) = {} + c(Set(1)) + c(Array(1)) +} diff --git a/tests/pos/orinf.scala b/tests/pos/orinf.scala deleted file mode 100644 index 30b7fd2f6353..000000000000 --- a/tests/pos/orinf.scala +++ /dev/null @@ -1,6 +0,0 @@ -object Test { - - def foo(lis: scala.collection.immutable.Set[Int] | scala.collection.immutable.Set[String]) = lis - foo(Set(1)) - foo(Set("")) -}