From 37367988dd9c61c387389489c3e6379f555405a3 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Fri, 11 Jan 2019 14:33:21 +1000 Subject: [PATCH 1/2] Detect and deal with non-RefTree captures --- .../scala/async/internal/AsyncTransform.scala | 56 +++++++++------ .../scala/scala/async/internal/Lifter.scala | 21 ++++-- .../scala/async/run/late/LateExpansion.scala | 68 +++++++++++++++++-- 3 files changed, 114 insertions(+), 31 deletions(-) diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala index ba0b522b..12372eda 100644 --- a/src/main/scala/scala/async/internal/AsyncTransform.scala +++ b/src/main/scala/scala/async/internal/AsyncTransform.scala @@ -158,29 +158,43 @@ trait AsyncTransform { // fields. Similarly, replace references to them with references to the field. // // This transform will only be run on the RHS of `def foo`. - val useFields: (Tree, TypingTransformApi) => Tree = (tree, api) => tree match { - case _ if api.currentOwner == stateMachineClass => - api.default(tree) - case ValDef(_, _, _, rhs) if liftedSyms(tree.symbol) => - api.atOwner(api.currentOwner) { - val fieldSym = tree.symbol - if (fieldSym.asTerm.isLazy) Literal(Constant(())) - else { - val lhs = atPos(tree.pos) { - gen.mkAttributedStableRef(thisType(fieldSym.owner.asClass), fieldSym) + val useFields: (Tree, TypingTransformApi) => Tree = (tree, api) => { + val result: Tree = tree match { + case _ if api.currentOwner == stateMachineClass => + api.default(tree) + case ValDef(_, _, _, rhs) if liftedSyms(tree.symbol) => + api.atOwner(api.currentOwner) { + val fieldSym = tree.symbol + if (fieldSym.asTerm.isLazy) Literal(Constant(())) + else { + val lhs = atPos(tree.pos) { + gen.mkAttributedStableRef(thisType(fieldSym.owner.asClass), fieldSym) + } + treeCopy.Assign(tree, lhs, api.recur(rhs)).setType(definitions.UnitTpe).changeOwner(fieldSym, api.currentOwner) } - treeCopy.Assign(tree, lhs, api.recur(rhs)).setType(definitions.UnitTpe).changeOwner(fieldSym, api.currentOwner) } - } - case _: DefTree if liftedSyms(tree.symbol) => - EmptyTree - case Ident(name) if liftedSyms(tree.symbol) => - val fieldSym = tree.symbol - atPos(tree.pos) { - gen.mkAttributedStableRef(thisType(fieldSym.owner.asClass), fieldSym).setType(tree.tpe) - } - case _ => - api.default(tree) + case _: DefTree if liftedSyms(tree.symbol) => + EmptyTree + case Ident(name) if liftedSyms(tree.symbol) => + val fieldSym = tree.symbol + atPos(tree.pos) { + gen.mkAttributedStableRef(thisType(fieldSym.owner.asClass), fieldSym).setType(tree.tpe) + } + case ta: TypeApply => + api.default(tree) + case _ => + api.default(tree) + } + val resultType = if (result.tpe eq null) null else result.tpe.map { + case TypeRef(pre, sym, args) if liftedSyms.contains(sym) => + val tp1 = internal.typeRef(thisType(sym.owner.asClass), sym, args) + tp1 + case SingleType(pre, sym) if liftedSyms.contains(sym) => + val tp1 = internal.singleType(thisType(sym.owner.asClass), sym) + tp1 + case tp => tp + } + setType(result, resultType) } val liftablesUseFields = liftables.map { diff --git a/src/main/scala/scala/async/internal/Lifter.scala b/src/main/scala/scala/async/internal/Lifter.scala index dc6640da..7a049b46 100644 --- a/src/main/scala/scala/async/internal/Lifter.scala +++ b/src/main/scala/scala/async/internal/Lifter.scala @@ -1,6 +1,7 @@ package scala.async.internal import scala.collection.mutable +import scala.collection.mutable.ListBuffer trait Lifter { self: AsyncMacro => @@ -77,13 +78,25 @@ trait Lifter { // The direct references of each block, excluding references of `DefTree`-s which // are already accounted for. val stateIdToDirectlyReferenced: mutable.LinkedHashMap[Int, List[Symbol]] = { - val refs: List[(Int, Symbol)] = asyncStates.flatMap( - asyncState => asyncState.stats.filterNot(t => t.isDef && !isLabel(t.symbol)).flatMap(_.collect { + val result = new mutable.LinkedHashMap[Int, ListBuffer[Symbol]]() + asyncStates.foreach( + asyncState => asyncState.stats.filterNot(t => t.isDef && !isLabel(t.symbol)).foreach(_.foreach { case rt: RefTree - if symToDefiningState.contains(rt.symbol) => (asyncState.state, rt.symbol) + if symToDefiningState.contains(rt.symbol) => + result.getOrElseUpdate(asyncState.state, new ListBuffer) += rt.symbol + case tt: TypeTree => + tt.tpe.foreach { tp => + val termSym = tp.termSymbol + if (symToDefiningState.contains(termSym)) + result.getOrElseUpdate(asyncState.state, new ListBuffer) += termSym + val typeSym = tp.typeSymbol + if (symToDefiningState.contains(typeSym)) + result.getOrElseUpdate(asyncState.state, new ListBuffer) += typeSym + } + case _ => }) ) - toMultiMap(refs) + result.map { case (a, b) => (a, b.result())} } def liftableSyms: mutable.LinkedHashSet[Symbol] = { diff --git a/src/test/scala/scala/async/run/late/LateExpansion.scala b/src/test/scala/scala/async/run/late/LateExpansion.scala index 42506fc3..7bdb1e48 100644 --- a/src/test/scala/scala/async/run/late/LateExpansion.scala +++ b/src/test/scala/scala/async/run/late/LateExpansion.scala @@ -7,7 +7,6 @@ import org.junit.{Assert, Test} import scala.annotation.StaticAnnotation import scala.annotation.meta.{field, getter} -import scala.async.TreeInterrogation import scala.async.internal.AsyncId import scala.reflect.internal.util.ScalaClassLoader.URLClassLoader import scala.tools.nsc._ @@ -19,6 +18,56 @@ import scala.tools.nsc.transform.TypingTransformers // calls it from a new phase that runs after patmat. class LateExpansion { + @Test def testRewrittenApply(): Unit = { + val result = wrapAndRun( + """ + | object O { + | case class Foo(a: Any) + | } + | @autoawait def id(a: String) = a + | O.Foo + | id("foo") + id("bar") + | O.Foo(1) + | """.stripMargin) + assertEquals("Foo(1)", result.toString) + } + + @Test def testIsInstanceOfType(): Unit = { + val result = wrapAndRun( + """ + | class Outer + | @autoawait def id(a: String) = a + | val o = new Outer + | id("foo") + id("bar") + | ("": Object).isInstanceOf[o.type] + | """.stripMargin) + assertEquals(false, result) + } + + @Test def testIsInstanceOfTerm(): Unit = { + val result = wrapAndRun( + """ + | class Outer + | @autoawait def id(a: String) = a + | val o = new Outer + | id("foo") + id("bar") + | o.isInstanceOf[Outer] + | """.stripMargin) + assertEquals(true, result) + } + + @Test def testArrayLocalModule(): Unit = { + val result = wrapAndRun( + """ + | class Outer + | @autoawait def id(a: String) = a + | val O = "" + | id("foo") + id("bar") + | new Array[O.type](0) + | """.stripMargin) + assertEquals(classOf[Array[String]], result.getClass) + } + @Test def test0(): Unit = { val result = wrapAndRun( """ @@ -27,6 +76,7 @@ class LateExpansion { | """.stripMargin) assertEquals("foobar", result) } + @Test def testGuard(): Unit = { val result = wrapAndRun( """ @@ -143,6 +193,7 @@ class LateExpansion { |} | """.stripMargin) } + @Test def shadowing2(): Unit = { val result = run( """ @@ -369,6 +420,7 @@ class LateExpansion { } """) } + @Test def testNegativeArraySizeExceptionFine1(): Unit = { val result = run( """ @@ -389,18 +441,20 @@ class LateExpansion { } """) } + private def createTempDir(): File = { val f = File.createTempFile("output", "") f.delete() f.mkdirs() f } + def run(code: String): Any = { - // settings.processArgumentString("-Xprint:patmat,postpatmat,jvm -Ybackend:GenASM -nowarn") val out = createTempDir() try { val reporter = new StoreReporter val settings = new Settings(println(_)) + //settings.processArgumentString("-Xprint:refchecks,patmat,postpatmat,jvm -nowarn") settings.outdir.value = out.getAbsolutePath settings.embeddedDefaults(getClass.getClassLoader) val isInSBT = !settings.classpath.isSetByUser @@ -432,6 +486,7 @@ class LateExpansion { } abstract class LatePlugin extends Plugin { + import global._ override val components: List[PluginComponent] = List(new PluginComponent with TypingTransformers { @@ -448,16 +503,16 @@ abstract class LatePlugin extends Plugin { super.transform(tree) match { case ap@Apply(fun, args) if fun.symbol.hasAnnotation(autoAwaitSym) => localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, awaitSym), TypeTree(ap.tpe) :: Nil), ap :: Nil)) - case sel@Select(fun, _) if sel.symbol.hasAnnotation(autoAwaitSym) && !(tree.tpe.isInstanceOf[MethodTypeApi] || tree.tpe.isInstanceOf[PolyTypeApi] ) => + case sel@Select(fun, _) if sel.symbol.hasAnnotation(autoAwaitSym) && !(tree.tpe.isInstanceOf[MethodTypeApi] || tree.tpe.isInstanceOf[PolyTypeApi]) => localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, awaitSym), TypeTree(sel.tpe) :: Nil), sel :: Nil)) case dd: DefDef if dd.symbol.hasAnnotation(lateAsyncSym) => atOwner(dd.symbol) { - deriveDefDef(dd){ rhs: Tree => + deriveDefDef(dd) { rhs: Tree => val invoke = Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, asyncSym), TypeTree(rhs.tpe) :: Nil), List(rhs)) localTyper.typed(atPos(dd.pos)(invoke)) } } case vd: ValDef if vd.symbol.hasAnnotation(lateAsyncSym) => atOwner(vd.symbol) { - deriveValDef(vd){ rhs: Tree => + deriveValDef(vd) { rhs: Tree => val invoke = Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, asyncSym), TypeTree(rhs.tpe) :: Nil), List(rhs)) localTyper.typed(atPos(vd.pos)(invoke)) } @@ -468,6 +523,7 @@ abstract class LatePlugin extends Plugin { } } } + override def newPhase(prev: Phase): Phase = new StdPhase(prev) { override def apply(unit: CompilationUnit): Unit = { val translated = newTransformer(unit).transformUnit(unit) @@ -476,7 +532,7 @@ abstract class LatePlugin extends Plugin { } } - override val runsAfter: List[String] = "patmat" :: Nil + override val runsAfter: List[String] = "refchecks" :: Nil override val phaseName: String = "postpatmat" }) From 9bf63b6e8e3e4cdb88bda97d3905cf9aa4935576 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Fri, 25 Jan 2019 17:54:37 +1000 Subject: [PATCH 2/2] Less ambitious, more compatible, version of previous commit --- .../scala/async/internal/AsyncTransform.scala | 71 ++++++++++--------- .../scala/scala/async/TreeInterrogation.scala | 30 +++++--- .../scala/async/run/late/LateExpansion.scala | 3 +- 3 files changed, 61 insertions(+), 43 deletions(-) diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala index 12372eda..b35173a1 100644 --- a/src/main/scala/scala/async/internal/AsyncTransform.scala +++ b/src/main/scala/scala/async/internal/AsyncTransform.scala @@ -154,38 +154,9 @@ trait AsyncTransform { sym.asModule.moduleClass.setOwner(stateMachineClass) } } - // Replace the ValDefs in the splicee with Assigns to the corresponding lifted - // fields. Similarly, replace references to them with references to the field. - // - // This transform will only be run on the RHS of `def foo`. - val useFields: (Tree, TypingTransformApi) => Tree = (tree, api) => { - val result: Tree = tree match { - case _ if api.currentOwner == stateMachineClass => - api.default(tree) - case ValDef(_, _, _, rhs) if liftedSyms(tree.symbol) => - api.atOwner(api.currentOwner) { - val fieldSym = tree.symbol - if (fieldSym.asTerm.isLazy) Literal(Constant(())) - else { - val lhs = atPos(tree.pos) { - gen.mkAttributedStableRef(thisType(fieldSym.owner.asClass), fieldSym) - } - treeCopy.Assign(tree, lhs, api.recur(rhs)).setType(definitions.UnitTpe).changeOwner(fieldSym, api.currentOwner) - } - } - case _: DefTree if liftedSyms(tree.symbol) => - EmptyTree - case Ident(name) if liftedSyms(tree.symbol) => - val fieldSym = tree.symbol - atPos(tree.pos) { - gen.mkAttributedStableRef(thisType(fieldSym.owner.asClass), fieldSym).setType(tree.tpe) - } - case ta: TypeApply => - api.default(tree) - case _ => - api.default(tree) - } - val resultType = if (result.tpe eq null) null else result.tpe.map { + + def adjustType(tree: Tree): Tree = { + val resultType = if (tree.tpe eq null) null else tree.tpe.map { case TypeRef(pre, sym, args) if liftedSyms.contains(sym) => val tp1 = internal.typeRef(thisType(sym.owner.asClass), sym, args) tp1 @@ -194,7 +165,41 @@ trait AsyncTransform { tp1 case tp => tp } - setType(result, resultType) + setType(tree, resultType) + } + + // Replace the ValDefs in the splicee with Assigns to the corresponding lifted + // fields. Similarly, replace references to them with references to the field. + // + // This transform will only be run on the RHS of `def foo`. + val useFields: (Tree, TypingTransformApi) => Tree = (tree, api) => tree match { + case _ if api.currentOwner == stateMachineClass => + api.default(tree) + case ValDef(_, _, _, rhs) if liftedSyms(tree.symbol) => + api.atOwner(api.currentOwner) { + val fieldSym = tree.symbol + if (fieldSym.asTerm.isLazy) Literal(Constant(())) + else { + val lhs = atPos(tree.pos) { + gen.mkAttributedStableRef(thisType(fieldSym.owner.asClass), fieldSym) + } + treeCopy.Assign(tree, lhs, api.recur(rhs)).setType(definitions.UnitTpe).changeOwner(fieldSym, api.currentOwner) + } + } + case _: DefTree if liftedSyms(tree.symbol) => + EmptyTree + case Ident(name) if liftedSyms(tree.symbol) => + val fieldSym = tree.symbol + atPos(tree.pos) { + gen.mkAttributedStableRef(thisType(fieldSym.owner.asClass), fieldSym).setType(tree.tpe) + } + case sel @ Select(n@New(tt: TypeTree), nme.CONSTRUCTOR) => + adjustType(sel) + adjustType(n) + adjustType(tt) + sel + case _ => + api.default(tree) } val liftablesUseFields = liftables.map { diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index cc4febc2..1484f321 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -70,17 +70,29 @@ object TreeInterrogationApp extends App { val tree = tb.parse( """ | import scala.async.internal.AsyncId._ - | async { - | var b = true - | while(await(b)) { - | b = false - | } - | (1, 1) match { - | case (x, y) => await(2); println(x) - | } - | await(b) + | trait QBound { type D; trait ResultType { case class Inner() }; def toResult: ResultType = ??? } + | trait QD[Q <: QBound] { + | val operation: Q + | type D = operation.D | } | + | async { + | if (!"".isEmpty) { + | val treeResult = null.asInstanceOf[QD[QBound]] + | await(0) + | val y = treeResult.operation + | type RD = treeResult.operation.D + | (null: Object) match { + | case (_, _: RD) => ??? + | case _ => val x = y.toResult; x.Inner() + | } + | await(1) + | (y, null.asInstanceOf[RD]) + | "" + | } + | + | } + | | """.stripMargin) println(tree) val tree1 = tb.typeCheck(tree.duplicate) diff --git a/src/test/scala/scala/async/run/late/LateExpansion.scala b/src/test/scala/scala/async/run/late/LateExpansion.scala index 7bdb1e48..5261209b 100644 --- a/src/test/scala/scala/async/run/late/LateExpansion.scala +++ b/src/test/scala/scala/async/run/late/LateExpansion.scala @@ -3,7 +3,7 @@ package scala.async.run.late import java.io.File import junit.framework.Assert.assertEquals -import org.junit.{Assert, Test} +import org.junit.{Assert, Ignore, Test} import scala.annotation.StaticAnnotation import scala.annotation.meta.{field, getter} @@ -32,6 +32,7 @@ class LateExpansion { assertEquals("Foo(1)", result.toString) } + @Ignore("Need to use adjustType more pervasively in AsyncTransform, but that exposes bugs in {Type, ... }Symbol's cache invalidation") @Test def testIsInstanceOfType(): Unit = { val result = wrapAndRun( """