@@ -92,6 +92,8 @@ object TreeTransforms {
92
92
def prepareForPackageDef (tree : PackageDef )(implicit ctx : Context ) = this
93
93
def prepareForStats (trees : List [Tree ])(implicit ctx : Context ) = this
94
94
95
+ def prepareForUnit (tree : Tree )(implicit ctx : Context ) = this
96
+
95
97
def transformIdent (tree : Ident )(implicit ctx : Context , info : TransformerInfo ): Tree = tree
96
98
def transformSelect (tree : Select )(implicit ctx : Context , info : TransformerInfo ): Tree = tree
97
99
def transformThis (tree : This )(implicit ctx : Context , info : TransformerInfo ): Tree = tree
@@ -125,6 +127,8 @@ object TreeTransforms {
125
127
def transformStats (trees : List [Tree ])(implicit ctx : Context , info : TransformerInfo ): List [Tree ] = trees
126
128
def transformOther (tree : Tree )(implicit ctx : Context , info : TransformerInfo ): Tree = tree
127
129
130
+ def transformUnit (tree : Tree )(implicit ctx : Context , info : TransformerInfo ): Tree = tree
131
+
128
132
/** Transform tree using all transforms of current group (including this one) */
129
133
def transform (tree : Tree )(implicit ctx : Context , info : TransformerInfo ): Tree = info.group.transform(tree, info, 0 )
130
134
@@ -273,6 +277,7 @@ object TreeTransforms {
273
277
nxPrepTemplate = index(transformations, " prepareForTemplate" )
274
278
nxPrepPackageDef = index(transformations, " prepareForPackageDef" )
275
279
nxPrepStats = index(transformations, " prepareForStats" )
280
+ nxPrepUnit = index(transformations, " prepareForUnit" )
276
281
277
282
nxTransIdent = index(transformations, " transformIdent" )
278
283
nxTransSelect = index(transformations, " transformSelect" )
@@ -305,6 +310,7 @@ object TreeTransforms {
305
310
nxTransTemplate = index(transformations, " transformTemplate" )
306
311
nxTransPackageDef = index(transformations, " transformPackageDef" )
307
312
nxTransStats = index(transformations, " transformStats" )
313
+ nxTransUnit = index(transformations, " transformUnit" )
308
314
nxTransOther = index(transformations, " transformOther" )
309
315
}
310
316
@@ -412,6 +418,7 @@ object TreeTransforms {
412
418
var nxPrepTemplate : Array [Int ] = _
413
419
var nxPrepPackageDef : Array [Int ] = _
414
420
var nxPrepStats : Array [Int ] = _
421
+ var nxPrepUnit : Array [Int ] = _
415
422
416
423
var nxTransIdent : Array [Int ] = _
417
424
var nxTransSelect : Array [Int ] = _
@@ -444,6 +451,7 @@ object TreeTransforms {
444
451
var nxTransTemplate : Array [Int ] = _
445
452
var nxTransPackageDef : Array [Int ] = _
446
453
var nxTransStats : Array [Int ] = _
454
+ var nxTransUnit : Array [Int ] = _
447
455
var nxTransOther : Array [Int ] = _
448
456
}
449
457
@@ -454,7 +462,7 @@ object TreeTransforms {
454
462
455
463
override def run (implicit ctx : Context ): Unit = {
456
464
val curTree = ctx.compilationUnit.tpdTree
457
- val newTree = transform (curTree)
465
+ val newTree = macroTransform (curTree)
458
466
ctx.compilationUnit.tpdTree = newTree
459
467
}
460
468
@@ -517,16 +525,19 @@ object TreeTransforms {
517
525
val prepForTemplate : Mutator [Template ] = (trans, tree, ctx) => trans.prepareForTemplate(tree)(ctx)
518
526
val prepForPackageDef : Mutator [PackageDef ] = (trans, tree, ctx) => trans.prepareForPackageDef(tree)(ctx)
519
527
val prepForStats : Mutator [List [Tree ]] = (trans, trees, ctx) => trans.prepareForStats(trees)(ctx)
528
+ val prepForUnit : Mutator [Tree ] = (trans, tree, ctx) => trans.prepareForUnit(tree)(ctx)
520
529
521
- def transform (t : Tree )(implicit ctx : Context ): Tree = {
530
+ def macroTransform (t : Tree )(implicit ctx : Context ): Tree = {
522
531
val initialTransformations = transformations
523
532
val info = new TransformerInfo (initialTransformations, new NXTransformations (initialTransformations), this )
524
533
initialTransformations.zipWithIndex.foreach {
525
534
case (transform, id) =>
526
535
transform.idx = id
527
536
transform.init(ctx, info)
528
537
}
529
- transform(t, info, 0 )
538
+ implicit val mutatedInfo : TransformerInfo = mutateTransformers(info, prepForUnit, info.nx.nxPrepUnit, t, 0 )
539
+ if (mutatedInfo eq null ) t
540
+ else goUnit(transform(t, mutatedInfo, 0 ), mutatedInfo.nx.nxTransUnit(0 ))
530
541
}
531
542
532
543
@ tailrec
@@ -859,6 +870,15 @@ object TreeTransforms {
859
870
} else tree
860
871
}
861
872
873
+ @ tailrec
874
+ final private [TreeTransforms ] def goUnit (tree : Tree , cur : Int )(implicit ctx : Context , info : TransformerInfo ): Tree = {
875
+ if (cur < info.transformers.length) {
876
+ val trans = info.transformers(cur)
877
+ val t = trans.transformUnit(tree)(ctx.withPhase(trans.treeTransformPhase), info)
878
+ goUnit(t, info.nx.nxTransUnit(cur + 1 ))
879
+ } else tree
880
+ }
881
+
862
882
final private [TreeTransforms ] def goOther (tree : Tree , cur : Int )(implicit ctx : Context , info : TransformerInfo ): Tree = {
863
883
if (cur < info.transformers.length) {
864
884
val trans = info.transformers(cur)
@@ -1219,5 +1239,4 @@ object TreeTransforms {
1219
1239
def transformSubTrees [Tr <: Tree ](trees : List [Tr ], info : TransformerInfo , current : Int )(implicit ctx : Context ): List [Tr ] =
1220
1240
transformTrees(trees, info, current)(ctx).asInstanceOf [List [Tr ]]
1221
1241
}
1222
-
1223
1242
}
0 commit comments