Skip to content

Commit 3736798

Browse files
committed
Detect and deal with non-RefTree captures
1 parent 6d3cecb commit 3736798

File tree

3 files changed

+114
-31
lines changed

3 files changed

+114
-31
lines changed

src/main/scala/scala/async/internal/AsyncTransform.scala

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -158,29 +158,43 @@ trait AsyncTransform {
158158
// fields. Similarly, replace references to them with references to the field.
159159
//
160160
// This transform will only be run on the RHS of `def foo`.
161-
val useFields: (Tree, TypingTransformApi) => Tree = (tree, api) => tree match {
162-
case _ if api.currentOwner == stateMachineClass =>
163-
api.default(tree)
164-
case ValDef(_, _, _, rhs) if liftedSyms(tree.symbol) =>
165-
api.atOwner(api.currentOwner) {
166-
val fieldSym = tree.symbol
167-
if (fieldSym.asTerm.isLazy) Literal(Constant(()))
168-
else {
169-
val lhs = atPos(tree.pos) {
170-
gen.mkAttributedStableRef(thisType(fieldSym.owner.asClass), fieldSym)
161+
val useFields: (Tree, TypingTransformApi) => Tree = (tree, api) => {
162+
val result: Tree = tree match {
163+
case _ if api.currentOwner == stateMachineClass =>
164+
api.default(tree)
165+
case ValDef(_, _, _, rhs) if liftedSyms(tree.symbol) =>
166+
api.atOwner(api.currentOwner) {
167+
val fieldSym = tree.symbol
168+
if (fieldSym.asTerm.isLazy) Literal(Constant(()))
169+
else {
170+
val lhs = atPos(tree.pos) {
171+
gen.mkAttributedStableRef(thisType(fieldSym.owner.asClass), fieldSym)
172+
}
173+
treeCopy.Assign(tree, lhs, api.recur(rhs)).setType(definitions.UnitTpe).changeOwner(fieldSym, api.currentOwner)
171174
}
172-
treeCopy.Assign(tree, lhs, api.recur(rhs)).setType(definitions.UnitTpe).changeOwner(fieldSym, api.currentOwner)
173175
}
174-
}
175-
case _: DefTree if liftedSyms(tree.symbol) =>
176-
EmptyTree
177-
case Ident(name) if liftedSyms(tree.symbol) =>
178-
val fieldSym = tree.symbol
179-
atPos(tree.pos) {
180-
gen.mkAttributedStableRef(thisType(fieldSym.owner.asClass), fieldSym).setType(tree.tpe)
181-
}
182-
case _ =>
183-
api.default(tree)
176+
case _: DefTree if liftedSyms(tree.symbol) =>
177+
EmptyTree
178+
case Ident(name) if liftedSyms(tree.symbol) =>
179+
val fieldSym = tree.symbol
180+
atPos(tree.pos) {
181+
gen.mkAttributedStableRef(thisType(fieldSym.owner.asClass), fieldSym).setType(tree.tpe)
182+
}
183+
case ta: TypeApply =>
184+
api.default(tree)
185+
case _ =>
186+
api.default(tree)
187+
}
188+
val resultType = if (result.tpe eq null) null else result.tpe.map {
189+
case TypeRef(pre, sym, args) if liftedSyms.contains(sym) =>
190+
val tp1 = internal.typeRef(thisType(sym.owner.asClass), sym, args)
191+
tp1
192+
case SingleType(pre, sym) if liftedSyms.contains(sym) =>
193+
val tp1 = internal.singleType(thisType(sym.owner.asClass), sym)
194+
tp1
195+
case tp => tp
196+
}
197+
setType(result, resultType)
184198
}
185199

186200
val liftablesUseFields = liftables.map {

src/main/scala/scala/async/internal/Lifter.scala

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package scala.async.internal
22

33
import scala.collection.mutable
4+
import scala.collection.mutable.ListBuffer
45

56
trait Lifter {
67
self: AsyncMacro =>
@@ -77,13 +78,25 @@ trait Lifter {
7778
// The direct references of each block, excluding references of `DefTree`-s which
7879
// are already accounted for.
7980
val stateIdToDirectlyReferenced: mutable.LinkedHashMap[Int, List[Symbol]] = {
80-
val refs: List[(Int, Symbol)] = asyncStates.flatMap(
81-
asyncState => asyncState.stats.filterNot(t => t.isDef && !isLabel(t.symbol)).flatMap(_.collect {
81+
val result = new mutable.LinkedHashMap[Int, ListBuffer[Symbol]]()
82+
asyncStates.foreach(
83+
asyncState => asyncState.stats.filterNot(t => t.isDef && !isLabel(t.symbol)).foreach(_.foreach {
8284
case rt: RefTree
83-
if symToDefiningState.contains(rt.symbol) => (asyncState.state, rt.symbol)
85+
if symToDefiningState.contains(rt.symbol) =>
86+
result.getOrElseUpdate(asyncState.state, new ListBuffer) += rt.symbol
87+
case tt: TypeTree =>
88+
tt.tpe.foreach { tp =>
89+
val termSym = tp.termSymbol
90+
if (symToDefiningState.contains(termSym))
91+
result.getOrElseUpdate(asyncState.state, new ListBuffer) += termSym
92+
val typeSym = tp.typeSymbol
93+
if (symToDefiningState.contains(typeSym))
94+
result.getOrElseUpdate(asyncState.state, new ListBuffer) += typeSym
95+
}
96+
case _ =>
8497
})
8598
)
86-
toMultiMap(refs)
99+
result.map { case (a, b) => (a, b.result())}
87100
}
88101

89102
def liftableSyms: mutable.LinkedHashSet[Symbol] = {

src/test/scala/scala/async/run/late/LateExpansion.scala

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import org.junit.{Assert, Test}
77

88
import scala.annotation.StaticAnnotation
99
import scala.annotation.meta.{field, getter}
10-
import scala.async.TreeInterrogation
1110
import scala.async.internal.AsyncId
1211
import scala.reflect.internal.util.ScalaClassLoader.URLClassLoader
1312
import scala.tools.nsc._
@@ -19,6 +18,56 @@ import scala.tools.nsc.transform.TypingTransformers
1918
// calls it from a new phase that runs after patmat.
2019
class LateExpansion {
2120

21+
@Test def testRewrittenApply(): Unit = {
22+
val result = wrapAndRun(
23+
"""
24+
| object O {
25+
| case class Foo(a: Any)
26+
| }
27+
| @autoawait def id(a: String) = a
28+
| O.Foo
29+
| id("foo") + id("bar")
30+
| O.Foo(1)
31+
| """.stripMargin)
32+
assertEquals("Foo(1)", result.toString)
33+
}
34+
35+
@Test def testIsInstanceOfType(): Unit = {
36+
val result = wrapAndRun(
37+
"""
38+
| class Outer
39+
| @autoawait def id(a: String) = a
40+
| val o = new Outer
41+
| id("foo") + id("bar")
42+
| ("": Object).isInstanceOf[o.type]
43+
| """.stripMargin)
44+
assertEquals(false, result)
45+
}
46+
47+
@Test def testIsInstanceOfTerm(): Unit = {
48+
val result = wrapAndRun(
49+
"""
50+
| class Outer
51+
| @autoawait def id(a: String) = a
52+
| val o = new Outer
53+
| id("foo") + id("bar")
54+
| o.isInstanceOf[Outer]
55+
| """.stripMargin)
56+
assertEquals(true, result)
57+
}
58+
59+
@Test def testArrayLocalModule(): Unit = {
60+
val result = wrapAndRun(
61+
"""
62+
| class Outer
63+
| @autoawait def id(a: String) = a
64+
| val O = ""
65+
| id("foo") + id("bar")
66+
| new Array[O.type](0)
67+
| """.stripMargin)
68+
assertEquals(classOf[Array[String]], result.getClass)
69+
}
70+
2271
@Test def test0(): Unit = {
2372
val result = wrapAndRun(
2473
"""
@@ -27,6 +76,7 @@ class LateExpansion {
2776
| """.stripMargin)
2877
assertEquals("foobar", result)
2978
}
79+
3080
@Test def testGuard(): Unit = {
3181
val result = wrapAndRun(
3282
"""
@@ -143,6 +193,7 @@ class LateExpansion {
143193
|}
144194
| """.stripMargin)
145195
}
196+
146197
@Test def shadowing2(): Unit = {
147198
val result = run(
148199
"""
@@ -369,6 +420,7 @@ class LateExpansion {
369420
}
370421
""")
371422
}
423+
372424
@Test def testNegativeArraySizeExceptionFine1(): Unit = {
373425
val result = run(
374426
"""
@@ -389,18 +441,20 @@ class LateExpansion {
389441
}
390442
""")
391443
}
444+
392445
private def createTempDir(): File = {
393446
val f = File.createTempFile("output", "")
394447
f.delete()
395448
f.mkdirs()
396449
f
397450
}
451+
398452
def run(code: String): Any = {
399-
// settings.processArgumentString("-Xprint:patmat,postpatmat,jvm -Ybackend:GenASM -nowarn")
400453
val out = createTempDir()
401454
try {
402455
val reporter = new StoreReporter
403456
val settings = new Settings(println(_))
457+
//settings.processArgumentString("-Xprint:refchecks,patmat,postpatmat,jvm -nowarn")
404458
settings.outdir.value = out.getAbsolutePath
405459
settings.embeddedDefaults(getClass.getClassLoader)
406460
val isInSBT = !settings.classpath.isSetByUser
@@ -432,6 +486,7 @@ class LateExpansion {
432486
}
433487

434488
abstract class LatePlugin extends Plugin {
489+
435490
import global._
436491

437492
override val components: List[PluginComponent] = List(new PluginComponent with TypingTransformers {
@@ -448,16 +503,16 @@ abstract class LatePlugin extends Plugin {
448503
super.transform(tree) match {
449504
case ap@Apply(fun, args) if fun.symbol.hasAnnotation(autoAwaitSym) =>
450505
localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, awaitSym), TypeTree(ap.tpe) :: Nil), ap :: Nil))
451-
case sel@Select(fun, _) if sel.symbol.hasAnnotation(autoAwaitSym) && !(tree.tpe.isInstanceOf[MethodTypeApi] || tree.tpe.isInstanceOf[PolyTypeApi] ) =>
506+
case sel@Select(fun, _) if sel.symbol.hasAnnotation(autoAwaitSym) && !(tree.tpe.isInstanceOf[MethodTypeApi] || tree.tpe.isInstanceOf[PolyTypeApi]) =>
452507
localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, awaitSym), TypeTree(sel.tpe) :: Nil), sel :: Nil))
453508
case dd: DefDef if dd.symbol.hasAnnotation(lateAsyncSym) => atOwner(dd.symbol) {
454-
deriveDefDef(dd){ rhs: Tree =>
509+
deriveDefDef(dd) { rhs: Tree =>
455510
val invoke = Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, asyncSym), TypeTree(rhs.tpe) :: Nil), List(rhs))
456511
localTyper.typed(atPos(dd.pos)(invoke))
457512
}
458513
}
459514
case vd: ValDef if vd.symbol.hasAnnotation(lateAsyncSym) => atOwner(vd.symbol) {
460-
deriveValDef(vd){ rhs: Tree =>
515+
deriveValDef(vd) { rhs: Tree =>
461516
val invoke = Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, asyncSym), TypeTree(rhs.tpe) :: Nil), List(rhs))
462517
localTyper.typed(atPos(vd.pos)(invoke))
463518
}
@@ -468,6 +523,7 @@ abstract class LatePlugin extends Plugin {
468523
}
469524
}
470525
}
526+
471527
override def newPhase(prev: Phase): Phase = new StdPhase(prev) {
472528
override def apply(unit: CompilationUnit): Unit = {
473529
val translated = newTransformer(unit).transformUnit(unit)
@@ -476,7 +532,7 @@ abstract class LatePlugin extends Plugin {
476532
}
477533
}
478534

479-
override val runsAfter: List[String] = "patmat" :: Nil
535+
override val runsAfter: List[String] = "refchecks" :: Nil
480536
override val phaseName: String = "postpatmat"
481537

482538
})

0 commit comments

Comments
 (0)