Skip to content

Fix #6047: Implement variance rules for match types #6050

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Mar 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,9 @@ class Definitions {

lazy val TypeBox_CAP: TypeSymbol = TypeBoxType.symbol.requiredType(tpnme.CAP)

lazy val MatchCaseType: TypeRef = ctx.requiredClassRef("scala.internal.MatchCase")
def MatchCaseClass(implicit ctx: Context): ClassSymbol = MatchCaseType.symbol.asClass

lazy val NotType: TypeRef = ctx.requiredClassRef("scala.implicits.Not")
def NotClass(implicit ctx: Context): ClassSymbol = NotType.symbol.asClass
def NotModule(implicit ctx: Context): Symbol = NotClass.companionModule
Expand Down Expand Up @@ -931,6 +934,23 @@ class Definitions {
}
}

object MatchCase {
def apply(pat: Type, body: Type)(implicit ctx: Context): Type =
MatchCaseType.appliedTo(pat, body)
def unapply(tp: Type)(implicit ctx: Context): Option[(Type, Type)] = tp match {
case AppliedType(tycon, pat :: body :: Nil) if tycon.isRef(MatchCaseClass) =>
Some((pat, body))
case _ =>
None
}
def isInstance(tp: Type)(implicit ctx: Context): Boolean = tp match {
case AppliedType(tycon: TypeRef, _) =>
tycon.name == tpnme.MatchCase && // necessary pre-filter to avoid forcing symbols
tycon.isRef(MatchCaseClass)
case _ => false
}
}

/** An extractor for multi-dimensional arrays.
* Note that this will also extract the high bound if an
* element type is a wildcard. E.g.
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ object StdNames {
val Literal: N = "Literal"
val LiteralAnnotArg: N = "LiteralAnnotArg"
val longHash: N = "longHash"
val MatchCase: N = "MatchCase"
val Modifiers: N = "Modifiers"
val NestedAnnotArg: N = "NestedAnnotArg"
val NoFlags: N = "NoFlags"
Expand Down
71 changes: 36 additions & 35 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2134,45 +2134,46 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
}
}

var result: Type = NoType
var remainingCases = cases
while (!remainingCases.isEmpty) {
val (cas :: cass) = remainingCases
remainingCases = cass
val saved = constraint
try {
inFrozenConstraint {
val cas1 = cas match {
case cas: HKTypeLambda =>
caseLambda = constrained(cas)
caseLambda.resultType
/** Match a single case.
* @return Some(tp) if the match succeeds with type `tp`
* Some(NoType) if the match fails, and there is an overlap between pattern and scrutinee
* None if the match fails and we should consider the following cases
* because scrutinee and pattern do not overlap
*/
def matchCase(cas: Type): Option[Type] = {
val cas1 = cas match {
case cas: HKTypeLambda =>
caseLambda = constrained(cas)
caseLambda.resultType
case _ =>
cas
}
val defn.MatchCase(pat, body) = cas1
if (isSubType(scrut, pat))
// `scrut` is a subtype of `pat`: *It's a Match!*
Some {
caseLambda match {
case caseLambda: HKTypeLambda =>
val instances = paramInstances(new Array(caseLambda.paramNames.length), pat)
instantiateParams(instances)(body)
case _ =>
cas
}
val defn.FunctionOf(pat :: Nil, body, _, _) = cas1
if (isSubType(scrut, pat)) {
// `scrut` is a subtype of `pat`: *It's a Match!*
result = caseLambda match {
case caseLambda: HKTypeLambda =>
val instances = paramInstances(new Array(caseLambda.paramNames.length), pat)
instantiateParams(instances)(body)
case _ =>
body
}
remainingCases = Nil
} else if (!intersecting(scrut, pat)) {
// We found a proof that `scrut` and `pat` are incompatible.
// The search continues.
} else {
// We are stuck: this match type instanciation is irreducible.
result = NoType
remainingCases = Nil
body
}
}
}
finally constraint = saved
else if (intersecting(scrut, pat))
Some(NoType)
else
// We found a proof that `scrut` and `pat` are incompatible.
// The search continues.
None
}
result

def recur(cases: List[Type]): Type = cases match {
case cas :: cases1 => matchCase(cas).getOrElse(recur(cases1))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the inFrozenConstraint be pushed inside? The way this is written, it's possible that matchCase(case1) returns None while still infer new constrains, which will then be visible when calling matchCase(case2) to matchCase(caseN).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see how that could happen. Once the constraint is frozen is stays so (in the same typeComparer). Which constraints would be computed by matchCase(case1) and propagated to matchCase(case2)?

case Nil => NoType
}

inFrozenConstraint(recur(cases))
}
}

Expand Down
26 changes: 17 additions & 9 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3756,7 +3756,7 @@ object Types {

def caseType(tp: Type)(implicit ctx: Context): Type = tp match {
case tp: HKTypeLambda => caseType(tp.resType)
case defn.FunctionOf(_, restpe, _, _) => restpe
case defn.MatchCase(_, body) => body
}

def alternatives(implicit ctx: Context): List[Type] = cases.map(caseType)
Expand Down Expand Up @@ -4417,10 +4417,12 @@ object Types {

case tp: LambdaType =>
def mapOverLambda = {
variance = -variance
val restpe = tp.resultType
val saved = variance
variance = if (defn.MatchCase.isInstance(restpe)) 0 else -variance
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having to do this kind of distinction in TypeMap makes me wonder if reusing LambdaType-s to encode cases of MatchType is the correct approach. Maybe MatchCase should simply be its own independent type...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a lot of machinery associated with LambdaType which would have to be duplicated. So I think what we have is the lesser evil.

val ptypes1 = tp.paramInfos.mapConserve(this).asInstanceOf[List[tp.PInfo]]
variance = -variance
derivedLambdaType(tp)(ptypes1, this(tp.resultType))
variance = saved
derivedLambdaType(tp)(ptypes1, this(restpe))
}
mapOverLambda

Expand All @@ -4440,7 +4442,9 @@ object Types {
derivedOrType(tp, this(tp.tp1), this(tp.tp2))

case tp: MatchType =>
derivedMatchType(tp, this(tp.bound), this(tp.scrutinee), tp.cases.mapConserve(this))
val bound1 = this(tp.bound)
val scrut1 = atVariance(0)(this(tp.scrutinee))
derivedMatchType(tp, bound1, scrut1, tp.cases.mapConserve(this))

case tp: SkolemType =>
derivedSkolemType(tp, this(tp.info))
Expand Down Expand Up @@ -4804,10 +4808,12 @@ object Types {
case _: BoundType | _: ThisType => x

case tp: LambdaType =>
variance = -variance
val restpe = tp.resultType
val saved = variance
variance = if (defn.MatchCase.isInstance(restpe)) 0 else -variance
val y = foldOver(x, tp.paramInfos)
variance = -variance
this(y, tp.resultType)
variance = saved
this(y, restpe)

case tp: TermRef =>
if (stopAtStatic && tp.currentSymbol.isStatic || (tp.prefix `eq` NoPrefix)) x
Expand Down Expand Up @@ -4835,7 +4841,9 @@ object Types {
this(this(x, tp.tp1), tp.tp2)

case tp: MatchType =>
foldOver(this(this(x, tp.bound), tp.scrutinee), tp.cases)
val x1 = this(x, tp.bound)
val x2 = atVariance(0)(this(x1, tp.scrutinee))
foldOver(x2, tp.cases)

case AnnotatedType(underlying, annot) =>
this(applyToAnnot(x, annot), underlying)
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/printing/Formatting.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ object Formatting {
case arg: Showable =>
try arg.show
catch {
case ex: CyclicReference => "... (caught cyclic reference) ..."
case NonFatal(ex)
if !ctx.mode.is(Mode.PrintShowExceptions) &&
!ctx.settings.YshowPrintErrors.value =>
Expand Down
5 changes: 4 additions & 1 deletion compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,10 @@ class PlainPrinter(_ctx: Context) extends Printer {
changePrec(OrTypePrec) { toText(tp1) ~ " | " ~ atPrec(OrTypePrec + 1) { toText(tp2) } }
case MatchType(bound, scrutinee, cases) =>
changePrec(GlobalPrec) {
def caseText(tp: Type): Text = "case " ~ toText(tp)
def caseText(tp: Type): Text = tp match {
case defn.MatchCase(pat, body) => "case " ~ toText(pat) ~ " => " ~ toText(body)
case _ => "case " ~ toText(tp)
}
def casesText = Text(cases.map(caseText), "\n")
atPrec(InfixPrec) { toText(scrutinee) } ~
keywordStr(" match ") ~ "{" ~ casesText ~ "}" ~
Expand Down
29 changes: 21 additions & 8 deletions compiler/src/dotty/tools/dotc/reporting/Reporter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import core.Mode
import dotty.tools.dotc.core.Symbols.{Symbol, NoSymbol}
import diagnostic.messages._
import diagnostic._
import ast.{tpd, Trees}
import Message._

object Reporter {
Expand Down Expand Up @@ -89,21 +90,25 @@ trait Reporting { this: Context =>
}

def warning(msg: => Message, pos: SourcePosition = NoSourcePosition): Unit =
reportWarning(new Warning(msg, pos))
reportWarning(new Warning(msg, addInlineds(pos)))

def strictWarning(msg: => Message, pos: SourcePosition = NoSourcePosition): Unit =
if (this.settings.strict.value) error(msg, pos)
else reportWarning(new ExtendMessage(() => msg)(_ + "\n(This would be an error under strict mode)").warning(pos))
def strictWarning(msg: => Message, pos: SourcePosition = NoSourcePosition): Unit = {
val fullPos = addInlineds(pos)
if (this.settings.strict.value) error(msg, fullPos)
else reportWarning(new ExtendMessage(() => msg)(_ + "\n(This would be an error under strict mode)").warning(fullPos))
}

def error(msg: => Message, pos: SourcePosition = NoSourcePosition): Unit =
reporter.report(new Error(msg, pos))
reporter.report(new Error(msg, addInlineds(pos)))

def errorOrMigrationWarning(msg: => Message, pos: SourcePosition = NoSourcePosition): Unit =
if (ctx.scala2Mode) migrationWarning(msg, pos) else error(msg, pos)
def errorOrMigrationWarning(msg: => Message, pos: SourcePosition = NoSourcePosition): Unit = {
val fullPos = addInlineds(pos)
if (ctx.scala2Mode) migrationWarning(msg, fullPos) else error(msg, fullPos)
}

def restrictionError(msg: => Message, pos: SourcePosition = NoSourcePosition): Unit =
reporter.report {
new ExtendMessage(() => msg)(m => s"Implementation restriction: $m").error(pos)
new ExtendMessage(() => msg)(m => s"Implementation restriction: $m").error(addInlineds(pos))
}

def incompleteInputError(msg: => Message, pos: SourcePosition = NoSourcePosition)(implicit ctx: Context): Unit =
Expand Down Expand Up @@ -135,6 +140,14 @@ trait Reporting { this: Context =>

def debugwarn(msg: => String, pos: SourcePosition = NoSourcePosition): Unit =
if (this.settings.Ydebug.value) warning(msg, pos)

private def addInlineds(pos: SourcePosition)(implicit ctx: Context) = {
def recur(pos: SourcePosition, inlineds: List[Trees.Tree[_]]): SourcePosition = inlineds match {
case inlined :: inlineds1 => pos.withOuter(recur(inlined.sourcePos, inlineds1))
case Nil => pos
}
recur(pos, tpd.enclosingInlineds)
}
}

/**
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ trait TypeAssigner {
}
HKTypeLambda.fromParams(
params(new mutable.ListBuffer[TypeSymbol](), pat).toList,
defn.FunctionOf(pat.tpe :: Nil, body.tpe))
defn.MatchCase(pat.tpe, body.tpe))
}
else body.tpe
tree.withType(ownType)
Expand Down
18 changes: 11 additions & 7 deletions compiler/src/dotty/tools/dotc/typer/VarianceChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ object VarianceChecker {
val paramVarianceStr = if (v == 0) "contra" else "co"
val occursStr = variance match {
case -1 => "contra"
case 0 => "non"
case 0 => "in"
case 1 => "co"
}
val pos = tree.tparams
Expand Down Expand Up @@ -123,18 +123,19 @@ class VarianceChecker()(implicit ctx: Context) {
def apply(status: Option[VarianceError], tp: Type): Option[VarianceError] = trace(s"variance checking $tp of $base at $variance", variances) {
try
if (status.isDefined) status
else tp match {
else tp.normalized match {
case tp: TypeRef =>
val sym = tp.symbol
if (sym.variance != 0 && base.isContainedIn(sym.owner)) checkVarianceOfSymbol(sym)
else if (sym.isAliasType) this(status, sym.info.bounds.hi)
else foldOver(status, tp)
else sym.info match {
case MatchAlias(_) => foldOver(status, tp)
case TypeAlias(alias) => this(status, alias)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not very clear to me than the new code for TypeAlias-es is equivalent to the previous one. Is .bounds.hi and .alias the same thing here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, it's the same thing. The previous code was a bit outdated.

case _ => foldOver(status, tp)
}
case tp: MethodOrPoly =>
this(status, tp.resultType) // params will be checked in their TypeDef or ValDef nodes.
case AnnotatedType(_, annot) if annot.symbol == defn.UncheckedVarianceAnnot =>
status
case tp: MatchType =>
apply(status, tp.bound)
case tp: ClassInfo =>
foldOver(status, tp.classParents)
case _ =>
Expand Down Expand Up @@ -179,7 +180,7 @@ class VarianceChecker()(implicit ctx: Context) {
sym.is(PrivateLocal) ||
sym.name.is(InlineAccessorName) || // TODO: should we exclude all synthetic members?
sym.is(TypeParam) && sym.owner.isClass // already taken care of in primary constructor of class
tree match {
try tree match {
case defn: MemberDef if skip =>
ctx.debuglog(s"Skipping variance check of ${sym.showDcl}")
case tree: TypeDef =>
Expand All @@ -196,6 +197,9 @@ class VarianceChecker()(implicit ctx: Context) {
vparamss foreach (_ foreach traverse)
case _ =>
}
catch {
case ex: TypeError => ctx.error(ex.toMessage, tree.sourcePos.focus)
}
}
}
}
2 changes: 1 addition & 1 deletion compiler/test-resources/repl/i5218
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
scala> val tuple = (1, "2", 3L)
val tuple: (Int, String, Long) = (1,2,3)
scala> 0.0 *: tuple
val res0: Double *: (Int, String, Long)(tuple) = (0.0,1,2,3)
val res0: (Double, Int, String, Long) = (0.0,1,2,3)
scala> tuple ++ tuple
val res1: Int *: String *: Long *:
scala.Tuple.Concat[Unit, (Int, String, Long)(tuple)] = (1,2,3,1,2,3)
1 change: 1 addition & 0 deletions compiler/test/dotc/run-test-pickling.blacklist
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ t3452g
t7374
tuples1.scala
tuples1a.scala
typeclass-derivation1.scala
typeclass-derivation2.scala
typeclass-derivation2a.scala
typeclass-derivation3.scala
Expand Down
Loading