Skip to content

Commit cf12129

Browse files
committed
Added methods to prepare-for and transform a complete compilation unit tree.
Should replace destructive inits.
1 parent 7978a5f commit cf12129

File tree

2 files changed

+28
-9
lines changed

2 files changed

+28
-9
lines changed

src/dotty/tools/dotc/transform/TreeTransform.scala

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ object TreeTransforms {
9292
def prepareForPackageDef(tree: PackageDef)(implicit ctx: Context) = this
9393
def prepareForStats(trees: List[Tree])(implicit ctx: Context) = this
9494

95+
def prepareForUnit(tree: Tree)(implicit ctx: Context) = this
96+
9597
def transformIdent(tree: Ident)(implicit ctx: Context, info: TransformerInfo): Tree = tree
9698
def transformSelect(tree: Select)(implicit ctx: Context, info: TransformerInfo): Tree = tree
9799
def transformThis(tree: This)(implicit ctx: Context, info: TransformerInfo): Tree = tree
@@ -125,6 +127,8 @@ object TreeTransforms {
125127
def transformStats(trees: List[Tree])(implicit ctx: Context, info: TransformerInfo): List[Tree] = trees
126128
def transformOther(tree: Tree)(implicit ctx: Context, info: TransformerInfo): Tree = tree
127129

130+
def transformUnit(tree: Tree)(implicit ctx: Context, info: TransformerInfo): Tree = tree
131+
128132
/** Transform tree using all transforms of current group (including this one) */
129133
def transform(tree: Tree)(implicit ctx: Context, info: TransformerInfo): Tree = info.group.transform(tree, info, 0)
130134

@@ -273,6 +277,7 @@ object TreeTransforms {
273277
nxPrepTemplate = index(transformations, "prepareForTemplate")
274278
nxPrepPackageDef = index(transformations, "prepareForPackageDef")
275279
nxPrepStats = index(transformations, "prepareForStats")
280+
nxPrepUnit = index(transformations, "prepareForUnit")
276281

277282
nxTransIdent = index(transformations, "transformIdent")
278283
nxTransSelect = index(transformations, "transformSelect")
@@ -305,6 +310,7 @@ object TreeTransforms {
305310
nxTransTemplate = index(transformations, "transformTemplate")
306311
nxTransPackageDef = index(transformations, "transformPackageDef")
307312
nxTransStats = index(transformations, "transformStats")
313+
nxTransUnit = index(transformations, "transformUnit")
308314
nxTransOther = index(transformations, "transformOther")
309315
}
310316

@@ -412,6 +418,7 @@ object TreeTransforms {
412418
var nxPrepTemplate: Array[Int] = _
413419
var nxPrepPackageDef: Array[Int] = _
414420
var nxPrepStats: Array[Int] = _
421+
var nxPrepUnit: Array[Int] = _
415422

416423
var nxTransIdent: Array[Int] = _
417424
var nxTransSelect: Array[Int] = _
@@ -444,6 +451,7 @@ object TreeTransforms {
444451
var nxTransTemplate: Array[Int] = _
445452
var nxTransPackageDef: Array[Int] = _
446453
var nxTransStats: Array[Int] = _
454+
var nxTransUnit: Array[Int] = _
447455
var nxTransOther: Array[Int] = _
448456
}
449457

@@ -454,7 +462,7 @@ object TreeTransforms {
454462

455463
override def run(implicit ctx: Context): Unit = {
456464
val curTree = ctx.compilationUnit.tpdTree
457-
val newTree = transform(curTree)
465+
val newTree = macroTransform(curTree)
458466
ctx.compilationUnit.tpdTree = newTree
459467
}
460468

@@ -517,16 +525,19 @@ object TreeTransforms {
517525
val prepForTemplate: Mutator[Template] = (trans, tree, ctx) => trans.prepareForTemplate(tree)(ctx)
518526
val prepForPackageDef: Mutator[PackageDef] = (trans, tree, ctx) => trans.prepareForPackageDef(tree)(ctx)
519527
val prepForStats: Mutator[List[Tree]] = (trans, trees, ctx) => trans.prepareForStats(trees)(ctx)
528+
val prepForUnit: Mutator[Tree] = (trans, tree, ctx) => trans.prepareForUnit(tree)(ctx)
520529

521-
def transform(t: Tree)(implicit ctx: Context): Tree = {
530+
def macroTransform(t: Tree)(implicit ctx: Context): Tree = {
522531
val initialTransformations = transformations
523532
val info = new TransformerInfo(initialTransformations, new NXTransformations(initialTransformations), this)
524533
initialTransformations.zipWithIndex.foreach {
525534
case (transform, id) =>
526535
transform.idx = id
527536
transform.init(ctx, info)
528537
}
529-
transform(t, info, 0)
538+
implicit val mutatedInfo: TransformerInfo = mutateTransformers(info, prepForUnit, info.nx.nxPrepUnit, t, 0)
539+
if (mutatedInfo eq null) t
540+
else goUnit(transform(t, mutatedInfo, 0), mutatedInfo.nx.nxTransUnit(0))
530541
}
531542

532543
@tailrec
@@ -859,6 +870,15 @@ object TreeTransforms {
859870
} else tree
860871
}
861872

873+
@tailrec
874+
final private[TreeTransforms] def goUnit(tree: Tree, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
875+
if (cur < info.transformers.length) {
876+
val trans = info.transformers(cur)
877+
val t = trans.transformUnit(tree)(ctx.withPhase(trans.treeTransformPhase), info)
878+
goUnit(t, info.nx.nxTransUnit(cur + 1))
879+
} else tree
880+
}
881+
862882
final private[TreeTransforms] def goOther(tree: Tree, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
863883
if (cur < info.transformers.length) {
864884
val trans = info.transformers(cur)
@@ -1219,5 +1239,4 @@ object TreeTransforms {
12191239
def transformSubTrees[Tr <: Tree](trees: List[Tr], info: TransformerInfo, current: Int)(implicit ctx: Context): List[Tr] =
12201240
transformTrees(trees, info, current)(ctx).asInstanceOf[List[Tr]]
12211241
}
1222-
12231242
}

test/test/transform/TreeTransformerTest.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class TreeTransformerTest extends DottyTest {
2424

2525
override def phaseName: String = "test"
2626
}
27-
val transformed = transformer.transform(tree)
27+
val transformed = transformer.macroTransform(tree)
2828

2929
Assert.assertTrue("returns same tree if unmodified",
3030
tree eq transformed
@@ -46,7 +46,7 @@ class TreeTransformerTest extends DottyTest {
4646

4747
override def phaseName: String = "test"
4848
}
49-
val transformed = transformer.transform(tree)
49+
val transformed = transformer.macroTransform(tree)
5050

5151
Assert.assertTrue("returns same tree if unmodified",
5252
transformed.toString.contains("List(ValDef(Modifiers(,,List()),d,TypeTree[TypeRef(ThisType(module class scala),Int)],Literal(Constant(2)))")
@@ -77,7 +77,7 @@ class TreeTransformerTest extends DottyTest {
7777
override def phaseName: String = "test"
7878

7979
}
80-
val tr = transformer.transform(tree).toString
80+
val tr = transformer.macroTransform(tree).toString
8181

8282
Assert.assertTrue("node can rewrite children",
8383
tr.contains("Literal(Constant(2))") && !tr.contains("Literal(Constant(-1))")
@@ -123,7 +123,7 @@ class TreeTransformerTest extends DottyTest {
123123

124124
override def phaseName: String = "test"
125125
}
126-
val tr = transformer.transform(tree).toString
126+
val tr = transformer.macroTransform(tree).toString
127127

128128
Assert.assertTrue("node can rewrite children",
129129
tr.contains("Literal(Constant(3))")
@@ -191,7 +191,7 @@ class TreeTransformerTest extends DottyTest {
191191

192192
override def phaseName: String = "test"
193193
}
194-
val tr = transformer.transform(tree).toString
194+
val tr = transformer.macroTransform(tree).toString
195195
Assert.assertTrue("transformations aren't invoked multiple times",
196196
transformed1 == 2 && transformed2 == 3
197197
)

0 commit comments

Comments
 (0)