Skip to content

Commit e61ef5d

Browse files
committed
Refactor optimizations
Most optimizations are similar in the way they traverse a plan. This commit factors out the commonalities in a base class `PlanTransform`.
1 parent 661f392 commit e61ef5d

File tree

1 file changed

+133
-146
lines changed

1 file changed

+133
-146
lines changed

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

Lines changed: 133 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -336,38 +336,72 @@ object PatternMatcher {
336336

337337
// ----- Optimizing plans ---------------
338338

339+
/** A superclass for plan transforms */
340+
class PlanTransform extends (Plan => Plan) {
341+
protected val treeMap = new TreeMap {
342+
override def transform(tree: Tree)(implicit ctx: Context) = tree
343+
}
344+
def apply(tree: Tree): Tree = treeMap.transform(tree)
345+
def apply(plan: TestPlan): Plan = {
346+
plan.scrutinee = apply(plan.scrutinee)
347+
plan.onSuccess = apply(plan.onSuccess)
348+
plan.onFailure = apply(plan.onFailure)
349+
plan
350+
}
351+
def apply(plan: LetPlan): Plan = {
352+
plan.body = apply(plan.body)
353+
initializer(plan.sym) = apply(initializer(plan.sym))
354+
plan
355+
}
356+
def apply(plan: LabelledPlan): Plan = {
357+
plan.body = apply(plan.body)
358+
labelled(plan.sym) = apply(labelled(plan.sym))
359+
plan
360+
}
361+
def apply(plan: CallPlan): Plan = plan
362+
def apply(plan: Plan): Plan = plan match {
363+
case plan: TestPlan => apply(plan)
364+
case plan: LetPlan => apply(plan)
365+
case plan: LabelledPlan => apply(plan)
366+
case plan: CallPlan => apply(plan)
367+
case plan: CodePlan => plan
368+
}
369+
}
370+
339371
/** Reference counts for all variables and labels */
340-
private def referenceCount(plan: Plan): collection.Map[Symbol, Int] = {
341-
val count = new mutable.HashMap[Symbol, Int] {
342-
override def default(key: Symbol) = 0
343-
}
344-
val refCounter = new TreeTraverser {
345-
def traverse(tree: Tree)(implicit ctx: Context) = tree match {
346-
case tree: Ident =>
347-
if (initializer contains tree.symbol) count(tree.symbol) += 1
348-
case _ =>
349-
traverseChildren(tree)
372+
def referenceCount(plan: Plan): collection.Map[Symbol, Int] = {
373+
object RefCount extends PlanTransform {
374+
val count = new mutable.HashMap[Symbol, Int] {
375+
override def default(key: Symbol) = 0
376+
}
377+
override val treeMap = new TreeMap {
378+
override def transform(tree: Tree)(implicit ctx: Context) = tree match {
379+
case tree: Ident =>
380+
if (initializer contains tree.symbol) count(tree.symbol) += 1
381+
tree
382+
case _ =>
383+
super.transform(tree)
384+
}
385+
}
386+
override def apply(plan: LetPlan): Plan = {
387+
apply(plan.body)
388+
if (count(plan.sym) != 0 || !isPatmatGenerated(plan.sym))
389+
apply(initializer(plan.sym))
390+
plan
391+
}
392+
override def apply(plan: LabelledPlan): Plan = {
393+
apply(plan.body)
394+
if (count(plan.sym) != 0)
395+
apply(labelled(plan.sym))
396+
plan
397+
}
398+
override def apply(plan: CallPlan): Plan = {
399+
count(plan.label) += 1
400+
plan
350401
}
351402
}
352-
def traverse(plan: Plan): Unit = plan match {
353-
case plan: TestPlan =>
354-
refCounter.traverse(plan.scrutinee)
355-
traverse(plan.onSuccess)
356-
traverse(plan.onFailure)
357-
case LetPlan(sym, body) =>
358-
traverse(body)
359-
if (count(sym) != 0 || !isPatmatGenerated(sym))
360-
refCounter.traverse(initializer(sym))
361-
case LabelledPlan(sym, body) =>
362-
traverse(body)
363-
if (count(sym) != 0) traverse(labelled(sym))
364-
case CodePlan(_) =>
365-
;
366-
case CallPlan(label) =>
367-
count(label) += 1
368-
}
369-
traverse(plan)
370-
count
403+
RefCount(plan)
404+
RefCount.count
371405
}
372406

373407
/** Rewrite everywhere
@@ -384,40 +418,32 @@ object PatternMatcher {
384418
* -->
385419
* let L2 = B2 in let L1 = B1 in E
386420
*/
387-
private def hoistLabels(plan: Plan): Plan = plan match {
388-
case plan: TestPlan =>
421+
object hoistLabels extends PlanTransform {
422+
override def apply(plan: TestPlan): Plan =
389423
plan.onSuccess match {
390424
case lp @ LabelledPlan(sym, body) =>
391425
plan.onSuccess = body
392426
lp.body = plan
393-
hoistLabels(lp)
427+
apply(lp)
394428
case _ =>
395429
plan.onFailure match {
396430
case lp @ LabelledPlan(sym, body) =>
397431
plan.onFailure = body
398432
lp.body = plan
399-
hoistLabels(lp)
433+
apply(lp)
400434
case _ =>
401-
plan.onSuccess = hoistLabels(plan.onSuccess)
402-
plan.onFailure = hoistLabels(plan.onFailure)
403-
plan
435+
super.apply(plan)
404436
}
405437
}
406-
case plan @ LabelledPlan(sym, body) =>
407-
labelled(sym) match {
438+
override def apply(plan: LabelledPlan): Plan =
439+
labelled(plan.sym) match {
408440
case plan1: LabelledPlan =>
409-
labelled(sym) = plan1.body
441+
labelled(plan.sym) = plan1.body
410442
plan1.body = plan
411-
hoistLabels(plan1)
443+
apply(plan1)
412444
case _ =>
413-
plan.body = hoistLabels(plan.body)
414-
plan
445+
super.apply(plan)
415446
}
416-
case plan @ LetPlan(_, body) =>
417-
plan.body = hoistLabels(plan.body)
418-
plan
419-
case _ =>
420-
plan
421447
}
422448

423449
/** Eliminate tests that are redundant (known to be true or false).
@@ -475,26 +501,25 @@ object PatternMatcher {
475501
/** The tests with known outcomes valid at entry to label */
476502
val seenAtLabel = mutable.HashMap[Symbol, SeenTests]()
477503

478-
def transform(plan: Plan, seenTests: SeenTests): Plan = plan match {
479-
case plan: TestPlan =>
504+
class ElimRedundant(seenTests: SeenTests) extends PlanTransform {
505+
override def apply(plan: TestPlan): Plan = {
480506
val normPlan = normalize(plan)
481507
seenTests.get(normPlan) match {
482508
case Some(outcome) =>
483-
transform(if (outcome) plan.onSuccess else plan.onFailure, seenTests)
509+
apply(if (outcome) plan.onSuccess else plan.onFailure)
484510
case None =>
485-
plan.onSuccess = transform(plan.onSuccess, seenTests + (normPlan -> true))
486-
plan.onFailure = transform(plan.onFailure, seenTests + (normPlan -> false))
511+
plan.onSuccess = new ElimRedundant(seenTests + (normPlan -> true))(plan.onSuccess)
512+
plan.onFailure = new ElimRedundant(seenTests + (normPlan -> false))(plan.onFailure)
487513
plan
488514
}
489-
case plan @ LetPlan(_, body) =>
490-
plan.body = transform(body, seenTests)
491-
plan
492-
case plan @ LabelledPlan(label, body) =>
493-
plan.body = transform(body, seenTests)
494-
for (seenTests1 <- seenAtLabel.get(label))
495-
labelled(label) = transform(labelled(label), seenTests1)
515+
}
516+
override def apply(plan: LabelledPlan): Plan = {
517+
plan.body = apply(plan.body)
518+
for (seenTests1 <- seenAtLabel.get(plan.sym))
519+
labelled(plan.sym) = new ElimRedundant(seenTests1)(labelled(plan.sym))
496520
plan
497-
case plan: CallPlan =>
521+
}
522+
override def apply(plan: CallPlan): Plan = {
498523
val label = plan.label
499524
def redirect(target: Plan): Plan = {
500525
def forward(tst: TestPlan) = seenTests.get(tst) match {
@@ -510,18 +535,17 @@ object PatternMatcher {
510535
}
511536
redirect(labelled(label)) match {
512537
case target: CallPlan =>
513-
transform(target, seenTests)
538+
apply(target)
514539
case _ =>
515540
seenAtLabel(label) = seenAtLabel.get(label) match {
516541
case Some(seenTests1) => intersect(seenTests1, seenTests)
517542
case none => seenTests
518543
}
519544
plan
520545
}
521-
case _: CodePlan =>
522-
plan
546+
}
523547
}
524-
transform(plan, Map())
548+
new ElimRedundant(Map())(plan)
525549
}
526550

527551
/** Inline labelled blocks that are referenced only once.
@@ -530,29 +554,15 @@ object PatternMatcher {
530554
private def inlineLabelled(plan: Plan) = {
531555
val refCount = referenceCount(plan)
532556
def toDrop(sym: Symbol) = labelled.contains(sym) && refCount(sym) <= 1
533-
def transform(plan: Plan): Plan = plan match {
534-
case plan: TestPlan =>
535-
plan.onSuccess = transform(plan.onSuccess)
536-
plan.onFailure = transform(plan.onFailure)
537-
plan
538-
case plan @ LetPlan(_, body) =>
539-
plan.body = transform(body)
540-
plan
541-
case plan @ LabelledPlan(label, body) =>
542-
val body1 = transform(body)
543-
if (toDrop(label)) body1
544-
else {
545-
labelled(label) = transform(labelled(label))
546-
plan.body = body1
547-
plan
548-
}
549-
case CallPlan(label) =>
550-
if (refCount(label) == 1) transform(labelled(label))
557+
class Inliner extends PlanTransform {
558+
override def apply(plan: LabelledPlan): Plan =
559+
if (toDrop(plan.sym)) apply(plan.body) else super.apply(plan)
560+
override def apply(plan: CallPlan): Plan = {
561+
if (refCount(plan.label) == 1) apply(labelled(plan.label))
551562
else plan
552-
case plan: CodePlan =>
553-
plan
563+
}
554564
}
555-
transform(plan)
565+
(new Inliner)(plan)
556566
}
557567

558568
/** Merge variables that have the same right hand side
@@ -566,51 +576,40 @@ object PatternMatcher {
566576
override def hashCode: Int = tree.hash
567577
}
568578

569-
val treeMap = new TreeMap {
570-
override def transform(tree: Tree)(implicit ctx: Context) = tree match {
571-
case tree: Ident =>
572-
val sym = tree.symbol
573-
initializer.get(sym) match {
574-
case Some(id: Ident @unchecked)
575-
if isPatmatGenerated(sym) && isPatmatGenerated(id.symbol) =>
576-
transform(id)
577-
case none => tree
578-
}
579-
case _ =>
580-
super.transform(tree)
579+
class Merge(seenVars: Map[RHS, Symbol]) extends PlanTransform {
580+
override val treeMap = new TreeMap {
581+
override def transform(tree: Tree)(implicit ctx: Context) = tree match {
582+
case tree: Ident =>
583+
val sym = tree.symbol
584+
initializer.get(sym) match {
585+
case Some(id: Ident @unchecked)
586+
if isPatmatGenerated(sym) && isPatmatGenerated(id.symbol) =>
587+
transform(id)
588+
case none => tree
589+
}
590+
case _ =>
591+
super.transform(tree)
592+
}
581593
}
582-
}
583594

584-
def transform(plan: Plan, seenVars: Map[RHS, Symbol]): Plan = plan match {
585-
case plan: TestPlan =>
586-
plan.scrutinee = treeMap.transform(plan.scrutinee)
587-
plan.onSuccess = transform(plan.onSuccess, seenVars)
588-
plan.onFailure = transform(plan.onFailure, seenVars)
589-
plan
590-
case plan @ LetPlan(sym, body) =>
595+
override def apply(plan: LetPlan): Plan = {
591596
val seenVars1 =
592-
if (isPatmatGenerated(sym)) {
593-
val thisRhs = new RHS(initializer(sym))
597+
if (isPatmatGenerated(plan.sym)) {
598+
val thisRhs = new RHS(initializer(plan.sym))
594599
seenVars.get(thisRhs) match {
595600
case Some(seen) =>
596-
initializer(sym) = ref(seen)
601+
initializer(plan.sym) = ref(seen)
597602
case none =>
598603
}
599-
seenVars.updated(thisRhs, sym)
604+
seenVars.updated(thisRhs, plan.sym)
600605
}
601606
else seenVars
602-
initializer(sym) = treeMap.transform(initializer(sym))
603-
plan.body = transform(body, seenVars1)
604-
plan
605-
case plan @ LabelledPlan(label, body) =>
606-
labelled(label) = transform(labelled(label), seenVars)
607-
plan.body = transform(body, seenVars)
608-
plan
609-
case _ =>
607+
initializer(plan.sym) = apply(initializer(plan.sym))
608+
plan.body = new Merge(seenVars1)(plan.body)
610609
plan
610+
}
611611
}
612-
613-
transform(plan, Map())
612+
(new Merge(Map()))(plan)
614613
}
615614

616615
/** Inline let-bound trees and labelled blocks that are referenced only once.
@@ -624,38 +623,26 @@ object PatternMatcher {
624623
def toDrop(sym: Symbol) =
625624
initializer.contains(sym) && isPatmatGenerated(sym) && refCount(sym) <= 1 && sym != topSym
626625

627-
val treeMap = new TreeMap {
628-
override def transform(tree: Tree)(implicit ctx: Context) = tree match {
629-
case tree: Ident =>
630-
val sym = tree.symbol
631-
if (toDrop(sym)) transform(initializer(sym)) else tree
632-
case _ =>
633-
super.transform(tree)
626+
object Inliner extends PlanTransform {
627+
override val treeMap = new TreeMap {
628+
override def transform(tree: Tree)(implicit ctx: Context) = tree match {
629+
case tree: Ident =>
630+
val sym = tree.symbol
631+
if (toDrop(sym)) transform(initializer(sym)) else tree
632+
case _ =>
633+
super.transform(tree)
634+
}
634635
}
635-
}
636-
637-
def transform(plan: Plan): Plan = plan match {
638-
case plan: TestPlan =>
639-
plan.scrutinee = treeMap.transform(plan.scrutinee)
640-
plan.onSuccess = transform(plan.onSuccess)
641-
plan.onFailure = transform(plan.onFailure)
642-
plan
643-
case plan @ LetPlan(sym, body) =>
644-
val body1 = transform(body)
645-
if (toDrop(sym)) body1
636+
override def apply(plan: LetPlan): Plan = {
637+
if (toDrop(plan.sym)) apply(plan.body)
646638
else {
647-
initializer(sym) = treeMap.transform(initializer(sym))
648-
plan.body = body1
639+
initializer(plan.sym) = apply(initializer(plan.sym))
640+
plan.body = apply(plan.body)
649641
plan
650642
}
651-
case plan @ LabelledPlan(label, body) =>
652-
labelled(label) = transform(labelled(label))
653-
plan.body = transform(plan.body)
654-
plan
655-
case _ =>
656-
plan
643+
}
657644
}
658-
transform(plan)
645+
Inliner(plan)
659646
}
660647

661648
// ----- Generating trees from plans ---------------

0 commit comments

Comments
 (0)