Skip to content

Simplify handling of union types #2330

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 1, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -258,21 +258,19 @@ trait ConstraintHandling {
}

// First, solve the constraint.
var inst = approximation(param, fromBelow)
var inst = approximation(param, fromBelow).simplified

// Then, approximate by (1.) - (3.) and simplify as follows.
// 1. If instance is from below and is a singleton type, yet
// upper bound is not a singleton type, widen the instance.
if (fromBelow && isSingleton(inst) && !isSingleton(upperBound))
inst = inst.widen

inst = inst.simplified

// 2. If instance is from below and is a fully-defined union type, yet upper bound
// is not a union type, approximate the union type from above by an intersection
// of all common base types.
if (fromBelow && isOrType(inst) && isFullyDefined(inst) && !isOrType(upperBound))
inst = ctx.harmonizeUnion(inst)
if (fromBelow && isOrType(inst) && !isOrType(upperBound))
inst = inst.widenUnion

inst
}
Expand Down
1 change: 0 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,6 @@ class Definitions {
enterCompleteClassSymbol(
ScalaPackageClass, tpnme.Singleton, PureInterfaceCreationFlags | Final,
List(AnyClass.typeRef), EmptyScope)
def SingletonType = SingletonClass.typeRef

lazy val SeqType: TypeRef = ctx.requiredClassRef("scala.collection.Seq")
def SeqClass(implicit ctx: Context) = SeqType.symbol.asClass
Expand Down
11 changes: 9 additions & 2 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ class TypeComparer(initctx: Context) extends DotClass with ConstraintHandling {
else thirdTry(tp1, tp2)
case tp1 @ OrType(tp11, tp12) =>
def joinOK = tp2.dealias match {
case tp12: HKApply =>
case _: HKApply =>
// If we apply the default algorithm for `A[X] | B[Y] <: C[Z]` where `C` is a
// type parameter, we will instantiate `C` to `A` and then fail when comparing
// with `B[Y]`. To do the right thing, we need to instantiate `C` to the
Expand Down Expand Up @@ -1511,10 +1511,17 @@ class ExplainingTypeComparer(initctx: Context) extends TypeComparer(initctx) {

override def compareHkApply2(tp1: Type, tp2: HKApply, tycon2: Type, args2: List[Type]): Boolean = {
def addendum = ""
traceIndented(i"compareHkApply $tp1, $tp2$addendum") {
traceIndented(i"compareHkApply2 $tp1, $tp2$addendum") {
super.compareHkApply2(tp1, tp2, tycon2, args2)
}
}

override def compareHkApply1(tp1: HKApply, tycon1: Type, args1: List[Type], tp2: Type): Boolean = {
def addendum = ""
traceIndented(i"compareHkApply1 $tp1, $tp2$addendum") {
super.compareHkApply1(tp1, tycon1, args1, tp2)
}
}

override def toString = "Subtype trace:" + { try b.toString finally b.clear() }
}
31 changes: 0 additions & 31 deletions compiler/src/dotty/tools/dotc/core/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -273,37 +273,6 @@ trait TypeOps { this: Context => // TODO: Make standalone object.
}
}

/** Given a disjunction T1 | ... | Tn of types with potentially embedded
* type variables, constrain type variables further if this eliminates
* some of the branches of the disjunction. Do this also for disjunctions
* embedded in intersections, as parents in refinements, and in recursive types.
*
* For instance, if `A` is an unconstrained type variable, then
*
* ArrayBuffer[Int] | ArrayBuffer[A]
*
* is approximated by constraining `A` to be =:= to `Int` and returning `ArrayBuffer[Int]`
* instead of `ArrayBuffer[_ >: Int | A <: Int & A]`
*/
def harmonizeUnion(tp: Type): Type = tp match {
case tp: OrType =>
joinIfScala2(ctx.typeComparer.lub(harmonizeUnion(tp.tp1), harmonizeUnion(tp.tp2), canConstrain = true))
case tp @ AndType(tp1, tp2) =>
tp derived_& (harmonizeUnion(tp1), harmonizeUnion(tp2))
case tp: RefinedType =>
tp.derivedRefinedType(harmonizeUnion(tp.parent), tp.refinedName, tp.refinedInfo)
case tp: RecType =>
tp.rebind(harmonizeUnion(tp.parent))
case _ =>
tp
}

/** Under -language:Scala2: Replace or-types with their joins */
private def joinIfScala2(tp: Type) = tp match {
case tp: OrType if scala2Mode => tp.join
case _ => tp
}

/** Not currently needed:
*
def liftToRec(f: (Type, Type) => Type)(tp1: Type, tp2: Type)(implicit ctx: Context) = {
Expand Down
37 changes: 33 additions & 4 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -830,23 +830,23 @@ object Types {
* def o: Outer
* <o.x.type>.widen = o.C
*/
@tailrec final def widen(implicit ctx: Context): Type = widenSingleton match {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was the tailrec annotation removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was noise. I was against adding it, but it got merged before I could comment.

As I wrote then: Add a @tailrec only if it is important that the method is in fact tail-recursive. Either it is super-performance critical or can potentially recurse deeply. Do not add @tailrec just because the method happens to be tail recursive.

final def widen(implicit ctx: Context): Type = widenSingleton match {
case tp: ExprType => tp.resultType.widen
case tp => tp
}

/** Widen from singleton type to its underlying non-singleton
* base type by applying one or more `underlying` dereferences.
*/
@tailrec final def widenSingleton(implicit ctx: Context): Type = stripTypeVar match {
final def widenSingleton(implicit ctx: Context): Type = stripTypeVar match {
case tp: SingletonType if !tp.isOverloaded => tp.underlying.widenSingleton
case _ => this
}

/** Widen from TermRef to its underlying non-termref
* base type, while also skipping Expr types.
*/
@tailrec final def widenTermRefExpr(implicit ctx: Context): Type = stripTypeVar match {
final def widenTermRefExpr(implicit ctx: Context): Type = stripTypeVar match {
case tp: TermRef if !tp.isOverloaded => tp.underlying.widenExpr.widenTermRefExpr
case _ => this
}
Expand All @@ -860,7 +860,7 @@ object Types {
}

/** Widen type if it is unstable (i.e. an ExprType, or TermRef to unstable symbol */
@tailrec final def widenIfUnstable(implicit ctx: Context): Type = stripTypeVar match {
final def widenIfUnstable(implicit ctx: Context): Type = stripTypeVar match {
case tp: ExprType => tp.resultType.widenIfUnstable
case tp: TermRef if !tp.symbol.isStable => tp.underlying.widenIfUnstable
case _ => this
Expand All @@ -872,6 +872,35 @@ object Types {
case _ => this
}

/** If this type contains embedded union types, replace them by their joins.
* "Embedded" means: inside intersectons or recursive types, or in prefixes of refined types.
* If an embedded union is found, we first try to simplify or eliminate it by
* re-lubbing it while allowing type parameters to be constrained further.
* Any remaining union types are replaced by their joins.
*
* For instance, if `A` is an unconstrained type variable, then
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Incorrect indentation around here.

*
* ArrayBuffer[Int] | ArrayBuffer[A]
*
* is approximated by constraining `A` to be =:= to `Int` and returning `ArrayBuffer[Int]`
* instead of `ArrayBuffer[_ >: Int | A <: Int & A]`
*/
def widenUnion(implicit ctx: Context): Type = this match {
case OrType(tp1, tp2) =>
ctx.typeComparer.lub(tp1.widenUnion, tp2.widenUnion, canConstrain = true) match {
case union: OrType => union.join
case res => res
}
case tp @ AndType(tp1, tp2) =>
tp derived_& (tp1.widenUnion, tp2.widenUnion)
case tp: RefinedType =>
tp.derivedRefinedType(tp.parent.widenUnion, tp.refinedName, tp.refinedInfo)
case tp: RecType =>
tp.rebind(tp.parent.widenUnion)
case _ =>
this
}

/** Eliminate anonymous classes */
final def deAnonymize(implicit ctx: Context): Type = this match {
case tp:TypeRef if tp.symbol.isAnonymousClass =>
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1034,13 +1034,13 @@ class Namer { typer: Typer =>
// println(s"owner = ${sym.owner}, decls = ${sym.owner.info.decls.show}")
def isInline = sym.is(FinalOrInline, butNot = Method | Mutable)

// Widen rhs type and approximate `|' but keep ConstantTypes if
// Widen rhs type and eliminate `|' but keep ConstantTypes if
// definition is inline (i.e. final in Scala2) and keep module singleton types
// instead of widening to the underlying module class types.
def widenRhs(tp: Type): Type = tp.widenTermRefExpr match {
case ctp: ConstantType if isInline => ctp
case ref: TypeRef if ref.symbol.is(ModuleClass) => tp
case _ => ctx.harmonizeUnion(tp.widen)
case _ => tp.widen.widenUnion
}

// Replace aliases to Unit by Unit itself. If we leave the alias in
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ object ProtoTypes {
/** Create a new TypeVar that represents a dependent method parameter singleton */
def newDepTypeVar(tp: Type)(implicit ctx: Context): TypeVar = {
val poly = PolyType(DepParamName.fresh().toTypeName :: Nil)(
pt => TypeBounds.upper(AndType(tp, defn.SingletonType)) :: Nil,
pt => TypeBounds.upper(AndType(tp, defn.SingletonClass.typeRef)) :: Nil,
pt => defn.AnyType)
constrained(poly, untpd.EmptyTree, alwaysAddTypeVars = true)
._2.head.tpe.asInstanceOf[TypeVar]
Expand Down
1 change: 1 addition & 0 deletions compiler/test/dotc/tests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ class tests extends CompilerTest {
@Test def rewrites = compileFile(posScala2Dir, "rewrites", "-rewrite" :: scala2mode)

@Test def pos_t8146a = compileFile(posSpecialDir, "t8146a")(allowDeepSubtypes)
@Test def pos_jon = compileFile(posSpecialDir, "jon")(allowDeepSubtypes)
Copy link
Member

@smarter smarter Apr 30, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't it surprising that this requires deep subtypes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not look too deeply. But SeqFactory has a funky recursive type

SeqFactory[CC[X] <: Seq[X] with GenericTraversableTemplate[X, CC]] extends GenSeqFactory[CC] with TraversableFactory[CC]

so it's not so implausible.


@Test def pos_t5545 = {
// compile by hand in two batches, since junit lacks the infrastructure to
Expand Down
28 changes: 28 additions & 0 deletions tests/neg/union.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
object Test {

class A
class B extends A
class C extends A
class D extends A

val b = true
val x = if (b) new B else new C
val y: B | C = x // error
}

object O {
class A
class B
def f[T](x: T, y: T): T = x

val x: A = f(new A { }, new A)

val y1: A | B = f(new A { }, new B) // error
val y2: A | B = f[A | B](new A { }, new B) // ok

val z = if (???) new A{} else new B

val z1: A | B = z // error

val z2: A | B = if (???) new A else new B // ok
}
File renamed without changes.
2 changes: 1 addition & 1 deletion tests/pos/anonClassSubtyping.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ object O {

val x: A = f(new A { }, new A)

val y: A | B = f(new A { }, new B)
val z: A | B = if (???) new A{} else new A
}
2 changes: 1 addition & 1 deletion tests/pos/constraining-lub.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ object Test {

val x: Inv[Int] = inv(true)

def inv2(cond: Boolean) =
def inv2(cond: Boolean): Inv[Int] | Inv2[Int] =
if (cond) {
if (cond)
new Inv(1)
Expand Down
4 changes: 3 additions & 1 deletion tests/pos/intersection.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ object intersection {
val z = if (???) x else y

val a: A & B => Unit = z
val b: (A => Unit) | (B => Unit) = z
//val b: (A => Unit) | (B => Unit) = z // error under new or-type rules

val c: (A => Unit) | (B => Unit) = if (???) x else y // ok

type needsA = A => Nothing
type needsB = B => Nothing
Expand Down
11 changes: 0 additions & 11 deletions tests/pos/union.scala

This file was deleted.