Skip to content

Lambda invariant check #6767

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions compiler/src/dotty/tools/dotc/core/Phases.scala
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ object Phases {
private[this] var myErasedTypes = false
private[this] var myFlatClasses = false
private[this] var myRefChecked = false
private[this] var myLambdaLifted = false

private[this] var mySameMembersStartId = NoPhaseId
private[this] var mySameParentsStartId = NoPhaseId
Expand All @@ -371,6 +372,7 @@ object Phases {
final def erasedTypes: Boolean = myErasedTypes // Phase is after erasure
final def flatClasses: Boolean = myFlatClasses // Phase is after flatten
final def refChecked: Boolean = myRefChecked // Phase is after RefChecks
final def lambdaLifted: Boolean = myLambdaLifted // Phase is after LambdaLift

final def sameMembersStartId: Int = mySameMembersStartId
// id of first phase where all symbols are guaranteed to have the same members as in this phase
Expand All @@ -385,9 +387,10 @@ object Phases {
assert(start <= Periods.MaxPossiblePhaseId, s"Too many phases, Period bits overflow")
myBase = base
myPeriod = Period(NoRunId, start, end)
myErasedTypes = prev.getClass == classOf[Erasure] || prev.erasedTypes
myFlatClasses = prev.getClass == classOf[Flatten] || prev.flatClasses
myRefChecked = prev.getClass == classOf[RefChecks] || prev.refChecked
myErasedTypes = prev.getClass == classOf[Erasure] || prev.erasedTypes
myFlatClasses = prev.getClass == classOf[Flatten] || prev.flatClasses
myRefChecked = prev.getClass == classOf[RefChecks] || prev.refChecked
myLambdaLifted = prev.getClass == classOf[LambdaLift] || prev.lambdaLifted
mySameMembersStartId = if (changesMembers) id else prev.sameMembersStartId
mySameParentsStartId = if (changesParents) id else prev.sameParentsStartId
mySameBaseTypesStartId = if (changesBaseTypes) id else prev.sameBaseTypesStartId
Expand Down
26 changes: 25 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/TreeChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,16 @@ class TreeChecker extends Phase with SymTransformer {
res
}

// used to check invariant of lambda encoding
var nestingBlock: untpd.Block | Null = null
private def withBlock[T](block: untpd.Block)(op: => T): T = {
val outerBlock = nestingBlock
nestingBlock = block
val res = op
nestingBlock = outerBlock
res
}

def assertDefined(tree: untpd.Tree)(implicit ctx: Context): Unit =
if (tree.symbol.maybeOwner.isTerm)
assert(nowDefinedSyms contains tree.symbol, i"undefined symbol ${tree.symbol} at line " + tree.sourcePos.line)
Expand Down Expand Up @@ -407,8 +417,22 @@ class TreeChecker extends Phase with SymTransformer {
}
}

override def typedClosure(tree: untpd.Closure, pt: Type)(implicit ctx: Context): Tree = {
if (!ctx.phase.lambdaLifted) nestingBlock match {
case block @ Block((meth : DefDef) :: Nil, closure: Closure) =>
assert(meth.symbol == closure.meth.symbol, "closure.meth symbol not equal to method symbol. Block: " + block.show)

case block: untpd.Block =>
assert(false, "function literal are not properly formed as a block of DefDef and Closure. Found: " + tree.show + " Nesting block: " + block.show)

case null =>
assert(false, "function literal are not properly formed as a block of DefDef and Closure. Found: " + tree.show + " Nesting block: null")
}
super.typedClosure(tree, pt)
}

override def typedBlock(tree: untpd.Block, pt: Type)(implicit ctx: Context): Tree =
withDefinedSyms(tree.stats) { super.typedBlock(tree, pt) }
withBlock(tree) { withDefinedSyms(tree.stats) { super.typedBlock(tree, pt) } }

override def typedInlined(tree: untpd.Inlined, pt: Type)(implicit ctx: Context): Tree =
withDefinedSyms(tree.bindings) { super.typedInlined(tree, pt) }
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,7 @@ class Typer extends Namer
*/
protected def ensureNoLocalRefs(tree: Tree, pt: Type, localSyms: => List[Symbol])(implicit ctx: Context): Tree = {
def ascribeType(tree: Tree, pt: Type): Tree = tree match {
case block @ Block(stats, expr) =>
case block @ Block(stats, expr) if !expr.isInstanceOf[Closure] =>
val expr1 = ascribeType(expr, pt)
cpy.Block(block)(stats, expr1) withType expr1.tpe // no assignType here because avoid is redundant
case _ =>
Expand Down Expand Up @@ -3081,7 +3081,7 @@ class Typer extends Namer
}

tree match {
case _: MemberDef | _: PackageDef | _: Import | _: WithoutTypeOrPos[_] => tree
case _: MemberDef | _: PackageDef | _: Import | _: WithoutTypeOrPos[_] | _: Closure => tree
case _ => tree.tpe.widen match {
case tp: FlexType =>
ensureReported(tp)
Expand Down
4 changes: 2 additions & 2 deletions tests/neg/erased-5.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ object Test {
type UU[T] = erased T => Int

def main(args: Array[String]): Unit = {
fun { x =>
x // error: Cannot use `erased` value in a context that is not `erased`
fun { x => // error: `Int => Int` not compatible with `erased Int => Int`
x
}

fun {
Expand Down
9 changes: 3 additions & 6 deletions tests/neg/i2146.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
object Test {
case class A()
case class B()

def foo[A, B]: given A => given B => Int = { given b: B =>
42 // error: found Int, required: given A => given B => Int
class Test {
def foo[A, B]: given A => given B => Int = { given b: B => // error: found Int, required: given A => given B => Int
42
}
}
8 changes: 4 additions & 4 deletions tests/neg/i5311.check
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- [E007] Type Mismatch Error: tests/neg/i5311.scala:11:27 -------------------------------------------------------------
-- [E007] Type Mismatch Error: tests/neg/i5311.scala:11:9 --------------------------------------------------------------
11 | baz((x : s.T[Int]) => x) // error
| ^
| Found: s.T[Int] => s.T[Int]
| Required: m.Foo
| ^^^^^^^^^^^^^^^^^^
| Found: s.T[Int] => s.T[Int]
| Required: m.Foo
4 changes: 2 additions & 2 deletions tests/neg/i5592.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ object Test {
}

val eqSymmetric2: Forall[[x] =>> (y: Obj) => (EQ[x, y.type]) => (EQ[y.type, x])] = {
{ x: Obj => { y: Obj => { xEqy: EQ[x.type, y.type] => xEqy.commute } } } // error // error
{ x: Obj => { y: Obj => { xEqy: EQ[x.type, y.type] => xEqy.commute } } } // error
}

val eqSymmetric3: Forall[[x] =>> Forall[[y] =>> EQ[x, y] => EQ[y, x]]] = {
{ x: Obj => { y: Obj => { xEqy: EQ[x.type, y.type] => xEqy.commute } } } // error // error
{ x: Obj => { y: Obj => { xEqy: EQ[x.type, y.type] => xEqy.commute } } } // error
}
}