Skip to content

Commit b4377dc

Browse files
committed
More aggressive variable sharing
Employing label parameters, we can eliminate redundant expressions systematically. This produces now optimal code for reducable.scala.
1 parent e61ef5d commit b4377dc

File tree

4 files changed

+177
-43
lines changed

4 files changed

+177
-43
lines changed

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,13 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
171171
def SyntheticValDef(name: TermName, rhs: Tree)(implicit ctx: Context): ValDef =
172172
ValDef(ctx.newSymbol(ctx.owner, name, Synthetic, rhs.tpe.widen, coord = rhs.pos), rhs)
173173

174+
def DefDef(sym: TermSymbol, tparams: List[TypeSymbol], vparamss: List[List[TermSymbol]],
175+
resultType: Type, rhs: Tree)(implicit ctx: Context): DefDef =
176+
ta.assignType(
177+
untpd.DefDef(sym.name, tparams map TypeDef, vparamss.nestedMap(ValDef(_)),
178+
TypeTree(resultType), rhs),
179+
sym)
180+
174181
def DefDef(sym: TermSymbol, rhs: Tree = EmptyTree)(implicit ctx: Context): DefDef =
175182
ta.assignType(DefDef(sym, Function.const(rhs) _), sym)
176183

@@ -199,14 +206,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
199206
val (vparamss, rtp) = valueParamss(mtp)
200207
val targs = tparams map (_.typeRef)
201208
val argss = vparamss.nestedMap(vparam => Ident(vparam.termRef))
202-
ta.assignType(
203-
untpd.DefDef(
204-
sym.name,
205-
tparams map TypeDef,
206-
vparamss.nestedMap(ValDef(_)),
207-
TypeTree(rtp),
208-
rhsFn(targs)(argss)),
209-
sym)
209+
DefDef(sym, tparams, vparamss, rtp, rhsFn(targs)(argss))
210210
}
211211

212212
def TypeDef(sym: TypeSymbol)(implicit ctx: Context): TypeDef =

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

Lines changed: 112 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -74,20 +74,23 @@ object PatternMatcher {
7474
private val initializer = mutable.Map[Symbol, Tree]()
7575
private val labelled = mutable.Map[Symbol, Plan]()
7676

77+
private def newVar(rhs: Tree, flags: FlagSet): TermSymbol =
78+
ctx.newSymbol(ctx.owner, PatMatStdBinderName.fresh(), Synthetic | Case | flags,
79+
sanitize(rhs.tpe), coord = rhs.pos)
80+
7781
/** The plan `let x = rhs in body(x)` where `x` is a fresh variable */
7882
private def letAbstract(rhs: Tree)(body: Symbol => Plan): Plan = {
79-
val vble = ctx.newSymbol(ctx.owner, PatMatStdBinderName.fresh(), Synthetic | Case,
80-
sanitize(rhs.tpe), coord = rhs.pos)
83+
val vble = newVar(rhs, EmptyFlags)
8184
initializer(vble) = rhs
8285
LetPlan(vble, body(vble))
8386
}
8487

8588
/** The plan `let l = labelled in body(l)` where `l` is a fresh label */
86-
private def labelAbstract(labeld: Plan)(body: Plan => Plan): Plan = {
89+
private def labelAbstract(labeld: Plan)(body: (=> Plan) => Plan): Plan = {
8790
val label = ctx.newSymbol(ctx.owner, PatMatCaseName.fresh(), Synthetic | Label | Method,
8891
MethodType(Nil, resultType))
8992
labelled(label) = labeld
90-
LabelledPlan(label, body(CallPlan(label)))
93+
LabelledPlan(label, body(CallPlan(label, Nil)), Nil)
9194
}
9295

9396
/** Was symbol generated by pattern matcher? */
@@ -127,9 +130,10 @@ object PatternMatcher {
127130
}
128131

129132
case class LetPlan(sym: TermSymbol, var body: Plan) extends Plan
130-
case class LabelledPlan(sym: TermSymbol, var body: Plan) extends Plan
133+
case class LabelledPlan(sym: TermSymbol, var body: Plan, var params: List[TermSymbol]) extends Plan
131134
case class CodePlan(var tree: Tree) extends Plan
132-
case class CallPlan(label: TermSymbol) extends Plan
135+
case class CallPlan(label: TermSymbol,
136+
var args: List[(/*formal*/TermSymbol, /*actual*/TermSymbol)]) extends Plan
133137

134138
object TestPlan {
135139
def apply(test: Test, sym: Symbol, pos: Position, ons: Plan, onf: Plan): TestPlan =
@@ -368,16 +372,36 @@ object PatternMatcher {
368372
}
369373
}
370374

375+
private class RefCounter extends PlanTransform {
376+
val count = new mutable.HashMap[Symbol, Int] {
377+
override def default(key: Symbol) = 0
378+
}
379+
}
380+
371381
/** Reference counts for all variables and labels */
372-
def referenceCount(plan: Plan): collection.Map[Symbol, Int] = {
373-
object RefCount extends PlanTransform {
374-
val count = new mutable.HashMap[Symbol, Int] {
375-
override def default(key: Symbol) = 0
382+
private def labelRefCount(plan: Plan): collection.Map[Symbol, Int] = {
383+
object refCounter extends RefCounter {
384+
override def apply(plan: LabelledPlan): Plan = {
385+
apply(plan.body)
386+
if (count(plan.sym) != 0) apply(labelled(plan.sym))
387+
plan
388+
}
389+
override def apply(plan: CallPlan): Plan = {
390+
count(plan.label) += 1
391+
plan
376392
}
393+
}
394+
refCounter(plan)
395+
refCounter.count
396+
}
397+
398+
/** Reference counts for all variables and labels */
399+
private def varRefCount(plan: Plan): collection.Map[Symbol, Int] = {
400+
object refCounter extends RefCounter {
377401
override val treeMap = new TreeMap {
378402
override def transform(tree: Tree)(implicit ctx: Context) = tree match {
379403
case tree: Ident =>
380-
if (initializer contains tree.symbol) count(tree.symbol) += 1
404+
if (isPatmatGenerated(tree.symbol)) count(tree.symbol) += 1
381405
tree
382406
case _ =>
383407
super.transform(tree)
@@ -390,18 +414,18 @@ object PatternMatcher {
390414
plan
391415
}
392416
override def apply(plan: LabelledPlan): Plan = {
417+
apply(labelled(plan.sym))
393418
apply(plan.body)
394-
if (count(plan.sym) != 0)
395-
apply(labelled(plan.sym))
396419
plan
397420
}
398421
override def apply(plan: CallPlan): Plan = {
399-
count(plan.label) += 1
422+
for ((formal, actual) <- plan.args)
423+
if (count(formal) != 0) count(actual) += 1
400424
plan
401425
}
402426
}
403-
RefCount(plan)
404-
RefCount.count
427+
refCounter(plan)
428+
refCounter.count
405429
}
406430

407431
/** Rewrite everywhere
@@ -421,13 +445,13 @@ object PatternMatcher {
421445
object hoistLabels extends PlanTransform {
422446
override def apply(plan: TestPlan): Plan =
423447
plan.onSuccess match {
424-
case lp @ LabelledPlan(sym, body) =>
448+
case lp @ LabelledPlan(sym, body, _) =>
425449
plan.onSuccess = body
426450
lp.body = plan
427451
apply(lp)
428452
case _ =>
429453
plan.onFailure match {
430-
case lp @ LabelledPlan(sym, body) =>
454+
case lp @ LabelledPlan(sym, body, _) =>
431455
plan.onFailure = body
432456
lp.body = plan
433457
apply(lp)
@@ -552,7 +576,7 @@ object PatternMatcher {
552576
* Drop all labels that are not referenced anymore after this.
553577
*/
554578
private def inlineLabelled(plan: Plan) = {
555-
val refCount = referenceCount(plan)
579+
val refCount = labelRefCount(plan)
556580
def toDrop(sym: Symbol) = labelled.contains(sym) && refCount(sym) <= 1
557581
class Inliner extends PlanTransform {
558582
override def apply(plan: LabelledPlan): Plan =
@@ -575,8 +599,17 @@ object PatternMatcher {
575599
}
576600
override def hashCode: Int = tree.hash
577601
}
602+
type SeenVars = Map[RHS, TermSymbol]
578603

579-
class Merge(seenVars: Map[RHS, Symbol]) extends PlanTransform {
604+
/** The variables known at entry to label */
605+
val seenAtLabel = mutable.HashMap[Symbol, SeenVars]()
606+
607+
/** Parameters of label; these are passed additional variables
608+
* which are known at all callsites.
609+
*/
610+
val paramsOfLabel = mutable.HashMap[Symbol, SeenVars]()
611+
612+
class Merge(seenVars: SeenVars) extends PlanTransform {
580613
override val treeMap = new TreeMap {
581614
override def transform(tree: Tree)(implicit ctx: Context) = tree match {
582615
case tree: Ident =>
@@ -593,21 +626,49 @@ object PatternMatcher {
593626
}
594627

595628
override def apply(plan: LetPlan): Plan = {
629+
initializer(plan.sym) = apply(initializer(plan.sym))
596630
val seenVars1 =
597631
if (isPatmatGenerated(plan.sym)) {
598632
val thisRhs = new RHS(initializer(plan.sym))
599633
seenVars.get(thisRhs) match {
600634
case Some(seen) =>
601635
initializer(plan.sym) = ref(seen)
636+
seenVars
602637
case none =>
638+
seenVars.updated(thisRhs, plan.sym)
603639
}
604-
seenVars.updated(thisRhs, plan.sym)
605640
}
606641
else seenVars
607-
initializer(plan.sym) = apply(initializer(plan.sym))
608642
plan.body = new Merge(seenVars1)(plan.body)
609643
plan
610644
}
645+
646+
override def apply(plan: LabelledPlan): Plan = {
647+
seenAtLabel(plan.sym) = seenVars
648+
plan.body = apply(plan.body)
649+
val paramsMap = paramsOfLabel.getOrElse(plan.sym, Map())
650+
plan.params = paramsMap.values.toList.sortBy(_.name.toString)
651+
val seenVars1 = seenVars ++ paramsMap
652+
labelled(plan.sym) = new Merge(seenVars1)(labelled(plan.sym))
653+
plan
654+
}
655+
656+
override def apply(plan: CallPlan): Plan = {
657+
paramsOfLabel(plan.label) = paramsOfLabel.get(plan.label) match {
658+
case Some(params) =>
659+
params.filter { case (rhs, _) => seenVars.contains(rhs) }
660+
case none =>
661+
for ((rhs, _) <- seenVars if !seenAtLabel(plan.label).contains(rhs))
662+
yield (rhs, newVar(rhs.tree, Param))
663+
}
664+
plan.args =
665+
for {
666+
(rhs, actual) <- seenVars.toList
667+
formal <- paramsOfLabel(plan.label).get(rhs)
668+
}
669+
yield (formal -> actual)
670+
plan
671+
}
611672
}
612673
(new Merge(Map()))(plan)
613674
}
@@ -617,7 +678,7 @@ object PatternMatcher {
617678
* Also: hoist cases out of tests using `hoistLabelled`.
618679
*/
619680
private def inlineVars(plan: Plan): Plan = {
620-
val refCount = referenceCount(plan)
681+
val refCount = varRefCount(plan)
621682
val LetPlan(topSym, _) = plan
622683

623684
def toDrop(sym: Symbol) =
@@ -628,7 +689,8 @@ object PatternMatcher {
628689
override def transform(tree: Tree)(implicit ctx: Context) = tree match {
629690
case tree: Ident =>
630691
val sym = tree.symbol
631-
if (toDrop(sym)) transform(initializer(sym)) else tree
692+
if (toDrop(sym)) transform(initializer(sym))
693+
else tree
632694
case _ =>
633695
super.transform(tree)
634696
}
@@ -641,6 +703,16 @@ object PatternMatcher {
641703
plan
642704
}
643705
}
706+
override def apply(plan: LabelledPlan): Plan = {
707+
plan.params = plan.params.filter(refCount(_) != 0)
708+
super.apply(plan)
709+
}
710+
override def apply(plan: CallPlan): Plan = {
711+
plan.args = plan.args
712+
.filter(formalActual => refCount(formalActual._1) != 0)
713+
.sortBy(_._1.name.toString)
714+
plan
715+
}
644716
}
645717
Inliner(plan)
646718
}
@@ -758,18 +830,21 @@ object PatternMatcher {
758830
If(emitCondition(plan).withPos(plan.pos), emit(plan.onSuccess), emit(plan.onFailure))
759831
case LetPlan(sym, body) =>
760832
seq(ValDef(sym, initializer(sym).ensureConforms(sym.info)) :: Nil, emit(body))
761-
case LabelledPlan(label, body) =>
762-
seq(DefDef(label, emit(labelled(label))) :: Nil, emit(body))
833+
case LabelledPlan(label, body, params) =>
834+
label.info = MethodType.fromSymbols(params, resultType)
835+
val labelDef = DefDef(label, Nil, params :: Nil, resultType, emit(labelled(label)))
836+
seq(labelDef :: Nil, emit(body))
763837
case CodePlan(tree) =>
764838
tree
765-
case CallPlan(label) =>
766-
ref(label).ensureApplied
839+
case CallPlan(label, args) =>
840+
ref(label).appliedToArgs(args.map { case (_, actual) => ref(actual) })
767841
}
768842
}
769843

770844
/** Pretty-print plan; used for debugging */
771845
def show(plan: Plan): String = {
772-
val refCount = referenceCount(plan)
846+
val lrefCount = labelRefCount(plan)
847+
val vrefCount = varRefCount(plan)
773848
val sb = new StringBuilder
774849
val seen = mutable.Set[Int]()
775850
def showTest(test: Test) = test match {
@@ -788,18 +863,20 @@ object PatternMatcher {
788863
showPlan(onf)
789864
case LetPlan(sym, body) =>
790865
sb.append(i"Let($sym = ${initializer(sym)}}, ${body.id})")
791-
sb.append(s", refcount = ${refCount(sym)}")
866+
sb.append(s", refcount = ${vrefCount(sym)}")
792867
showPlan(body)
793-
case LabelledPlan(label, body) =>
868+
case LabelledPlan(label, body, params) =>
794869
val labeld = labelled(label)
795-
sb.append(i"Labelled($label = ${labeld.id}, ${body.id})")
796-
sb.append(s", refcount = ${refCount(label)}")
870+
def showParam(param: Symbol) =
871+
i"$param: ${param.info}, refCount = ${vrefCount(param)}"
872+
sb.append(i"Labelled($label(${params.map(showParam)}%, %) = ${labeld.id}, ${body.id})")
873+
sb.append(s", refcount = ${lrefCount(label)}")
797874
showPlan(body)
798875
showPlan(labeld)
799876
case CodePlan(tree) =>
800877
sb.append(tree.show)
801-
case CallPlan(label) =>
802-
sb.append(s"Call($label)")
878+
case CallPlan(label, params) =>
879+
sb.append(s"Call($label(${params.map(_._2)}%, %)")
803880
}
804881
}
805882
showPlan(plan)

tests/run/reducable.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
object Test extends App {
2+
val xs = List(1, 2, 3)
3+
4+
object Cons {
5+
var count = 0
6+
7+
def unapply[T](xs: List[T]): Option[(T, List[T])] = {
8+
count += 1
9+
xs match {
10+
case x :: xs1 => Some((x, xs1))
11+
case _ => None
12+
}
13+
}
14+
}
15+
16+
val res = xs match {
17+
case Cons(0, Nil) => 1
18+
case Cons(_, Nil) => 2
19+
case Cons(0, _) => 3
20+
case Cons(1, ys) => 4
21+
}
22+
23+
assert(res == 4, res)
24+
assert(Cons.count ==1, Cons.count)
25+
}

tests/run/switches.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import annotation.switch
2+
object Test extends App {
3+
4+
val x = 3
5+
final val Y = 3
6+
7+
val x1 = x match {
8+
case 0 => 0
9+
case 1 => 1
10+
case 2 => 2
11+
case Y => 3
12+
}
13+
14+
val x2 = (x: @switch) match {
15+
case 0 => 0
16+
case 1 | 2 => 2
17+
case Y => 3
18+
case _ => 4
19+
}
20+
21+
val x3 = (x: @switch) match {
22+
case '0' if x > 0 => 0
23+
case '1' => 1
24+
case '2' => 2
25+
case '3' => 3
26+
case x => 4
27+
}
28+
29+
assert(x1 == 3)
30+
assert(x2 == 3)
31+
assert(x3 == 4)
32+
}

0 commit comments

Comments
 (0)