Skip to content

Commit 068fc8e

Browse files
committed
Create RetypingTreeMap that propagates types
If some node in tree is transformed changing the type, the outer node could potentially also change type. This patch implements a RetypingTreeMap that propagates those changes until types converge. Propagation is done for tree nodes that are able to compute their type based on their children: Pair, Block, If, Match, CaseDef, Try, SeqLiteral, Annotated, Select.
1 parent de2ecc7 commit 068fc8e

File tree

3 files changed

+154
-30
lines changed

3 files changed

+154
-30
lines changed

src/dotty/tools/dotc/ast/Trees.scala

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1109,36 +1109,20 @@ object Trees {
11091109
cpy.Apply(tree, transform(fun), transform(args))
11101110
case TypeApply(fun, args) =>
11111111
cpy.TypeApply(tree, transform(fun), transform(args))
1112-
case Literal(const) =>
1113-
tree
11141112
case New(tpt) =>
11151113
cpy.New(tree, transform(tpt))
1116-
case Pair(left, right) =>
1117-
cpy.Pair(tree, transform(left), transform(right))
11181114
case Typed(expr, tpt) =>
11191115
cpy.Typed(tree, transform(expr), transform(tpt))
11201116
case NamedArg(name, arg) =>
11211117
cpy.NamedArg(tree, name, transform(arg))
11221118
case Assign(lhs, rhs) =>
11231119
cpy.Assign(tree, transform(lhs), transform(rhs))
1124-
case Block(stats, expr) =>
1125-
cpy.Block(tree, transformStats(stats), transform(expr))
1126-
case If(cond, thenp, elsep) =>
1127-
cpy.If(tree, transform(cond), transform(thenp), transform(elsep))
11281120
case Closure(env, meth, tpt) =>
11291121
cpy.Closure(tree, transform(env), transform(meth), transform(tpt))
1130-
case Match(selector, cases) =>
1131-
cpy.Match(tree, transform(selector), transformSub(cases))
1132-
case CaseDef(pat, guard, body) =>
1133-
cpy.CaseDef(tree, transform(pat), transform(guard), transform(body))
11341122
case Return(expr, from) =>
11351123
cpy.Return(tree, transform(expr), transformSub(from))
1136-
case Try(block, handler, finalizer) =>
1137-
cpy.Try(tree, transform(block), transform(handler), transform(finalizer))
11381124
case Throw(expr) =>
11391125
cpy.Throw(tree, transform(expr))
1140-
case SeqLiteral(elems) =>
1141-
cpy.SeqLiteral(tree, transform(elems))
11421126
case TypeTree(original) =>
11431127
tree
11441128
case SingletonTypeTree(ref) =>
@@ -1177,12 +1161,29 @@ object Trees {
11771161
cpy.Import(tree, transform(expr), selectors)
11781162
case PackageDef(pid, stats) =>
11791163
cpy.PackageDef(tree, transformSub(pid), transformStats(stats))
1180-
case Annotated(annot, arg) =>
1181-
cpy.Annotated(tree, transform(annot), transform(arg))
11821164
case Thicket(trees) =>
11831165
val trees1 = transform(trees)
11841166
if (trees1 eq trees) tree else Thicket(trees1)
1167+
case Literal(const) =>
1168+
tree
1169+
case Pair(left, right) =>
1170+
cpy.Pair(tree, transform(left), transform(right))
1171+
case Block(stats, expr) =>
1172+
cpy.Block(tree, transformStats(stats), transform(expr))
1173+
case If(cond, thenp, elsep) =>
1174+
cpy.If(tree, transform(cond), transform(thenp), transform(elsep))
1175+
case Match(selector, cases) =>
1176+
cpy.Match(tree, transform(selector), transformSub(cases))
1177+
case CaseDef(pat, guard, body) =>
1178+
cpy.CaseDef(tree, transform(pat), transform(guard), transform(body))
1179+
case Try(block, handler, finalizer) =>
1180+
cpy.Try(tree, transform(block), transform(handler), transform(finalizer))
1181+
case SeqLiteral(elems) =>
1182+
cpy.SeqLiteral(tree, transform(elems))
1183+
case Annotated(annot, arg) =>
1184+
cpy.Annotated(tree, transform(annot), transform(arg))
11851185
}
1186+
11861187
def transformStats(trees: List[Tree])(implicit ctx: Context): List[Tree] =
11871188
transform(trees)
11881189
def transform(trees: List[Tree])(implicit ctx: Context): List[Tree] =

src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 131 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@ package dotc
33
package ast
44

55
import core._
6+
import dotty.tools.dotc.transform.TypeUtils
67
import util.Positions._, Types._, Contexts._, Constants._, Names._, Flags._
78
import SymDenotations._, Symbols._, StdNames._, Annotations._, Trees._, Symbols._
89
import CheckTrees._, Denotations._, Decorators._
910
import config.Printers._
1011
import typer.ErrorReporting._
1112

13+
import scala.annotation.tailrec
14+
1215
/** Some creators for typed trees */
1316
object tpd extends Trees.Instance[Type] with TypedTreeInfo {
1417

@@ -413,6 +416,68 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
413416
def tpes: List[Type] = xs map (_.tpe)
414417
}
415418

419+
/** RetypingTreeMap is a TreeMap that is able to propagate type changes.
420+
*
421+
* This is required when types can change during transformation,
422+
* for example if `Block(stats, expr)` is being transformed
423+
* and type of `expr` changes from `TypeRef(prefix, name)` to `TypeRef(newPrefix, name)` with different prefix, t
424+
* type of enclosing Block should also change, otherwise the whole tree would not be type-correct anymore.
425+
* see `propagateType` methods for propagation rulles.
426+
*
427+
* TreeMap does not include such logic as it assumes that types of threes do not change during transformation.
428+
*/
429+
class RetypingTreeMap extends TreeMap {
430+
431+
override def transform(tree: Tree)(implicit ctx: Context): Tree = tree match {
432+
case tree@Select(qualifier, name) =>
433+
val tree1 = cpy.Select(tree, transform(qualifier), name)
434+
propagateType(tree, tree1)
435+
case tree@Pair(left, right) =>
436+
val left1 = transform(left)
437+
val right1 = transform(right)
438+
val tree1 = cpy.Pair(tree, left1, right1)
439+
propagateType(tree, tree1)
440+
case tree@Block(stats, expr) =>
441+
val stats1 = transform(stats)
442+
val expr1 = transform(expr)
443+
val tree1 = cpy.Block(tree, stats1, expr1)
444+
propagateType(tree, tree1)
445+
case tree@If(cond, thenp, elsep) =>
446+
val cond1 = transform(cond)
447+
val thenp1 = transform(thenp)
448+
val elsep1 = transform(elsep)
449+
val tree1 = cpy.If(tree, cond1, thenp1, elsep1)
450+
propagateType(tree, tree1)
451+
case tree@Match(selector, cases) =>
452+
val selector1 = transform(selector)
453+
val cases1 = transformSub(cases)
454+
val tree1 = cpy.Match(tree, selector1, cases1)
455+
propagateType(tree, tree1)
456+
case tree@CaseDef(pat, guard, body) =>
457+
val pat1 = transform(pat)
458+
val guard1 = transform(guard)
459+
val body1 = transform(body)
460+
val tree1 = cpy.CaseDef(tree, pat1, guard1, body1)
461+
propagateType(tree, tree1)
462+
case tree@Try(block, handler, finalizer) =>
463+
val expr1 = transform(block)
464+
val handler1 = transform(handler)
465+
val finalizer1 = transform(finalizer)
466+
val tree1 = cpy.Try(tree, expr1, handler1, finalizer1)
467+
propagateType(tree, tree1)
468+
case tree@SeqLiteral(elems) =>
469+
val elems1 = transform(elems)
470+
val tree1 = cpy.SeqLiteral(tree, elems1)
471+
propagateType(tree, tree1)
472+
case tree@Annotated(annot, arg) =>
473+
val annot1 = transform(annot)
474+
val arg1 = transform(arg)
475+
val tree1 = cpy.Annotated(tree, annot1, arg1)
476+
propagateType(tree, tree1)
477+
case _ => super.transform(tree)
478+
}
479+
}
480+
416481
/** A map that applies three functions together to a tree and makes sure
417482
* they are coordinated so that the result is well-typed. The functions are
418483
* @param typeMap A function from Type to type that gets applied to the
@@ -425,7 +490,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
425490
final class TreeTypeMap(
426491
val typeMap: Type => Type = IdentityTypeMap,
427492
val ownerMap: Symbol => Symbol = identity _,
428-
val treeMap: Tree => Tree = identity _)(implicit ctx: Context) extends TreeMap {
493+
val treeMap: Tree => Tree = identity _)(implicit ctx: Context) extends RetypingTreeMap {
429494

430495
override def transform(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = {
431496
val tree1 = treeMap(tree)
@@ -436,10 +501,16 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
436501
cpy.DefDef(ddef, mods, name, tparams1, vparamss1, tmap2.transform(tpt), tmap2.transform(rhs))
437502
case blk @ Block(stats, expr) =>
438503
val (tmap1, stats1) = transformDefs(stats)
439-
cpy.Block(blk, stats1, tmap1.transform(expr))
504+
val expr1 = tmap1.transform(expr)
505+
val tree1 = cpy.Block(blk, stats1, expr1)
506+
propagateType(blk, tree1)
440507
case cdef @ CaseDef(pat, guard, rhs) =>
441508
val tmap = withMappedSyms(patVars(pat))
442-
cpy.CaseDef(cdef, tmap.transform(pat), tmap.transform(guard), tmap.transform(rhs))
509+
val pat1 = tmap.transform(pat)
510+
val guard1 = tmap.transform(guard)
511+
val rhs1 = tmap.transform(rhs)
512+
val tree1 = cpy.CaseDef(tree, pat1, guard1, rhs1)
513+
propagateType(cdef, tree1)
443514
case tree1 =>
444515
super.transform(tree1)
445516
}
@@ -501,6 +572,56 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
501572
acc(Nil, tree)
502573
}
503574

575+
def propagateType(origTree: Pair, newTree: Pair)(implicit ctx: Context) = {
576+
if ((newTree eq origTree) ||
577+
((newTree.left.tpe eq origTree.left.tpe) && (newTree.right.tpe eq origTree.right.tpe))) newTree
578+
else ta.assignType(newTree, newTree.left, newTree.right)
579+
}
580+
581+
def propagateType(origTree: Block, newTree: Block)(implicit ctx: Context) = {
582+
if ((newTree eq origTree) || (newTree.expr.tpe eq origTree.expr.tpe)) newTree
583+
else ta.assignType(newTree, newTree.stats, newTree.expr)
584+
}
585+
586+
def propagateType(origTree: If, newTree: If)(implicit ctx: Context) = {
587+
if ((newTree eq origTree) ||
588+
((newTree.thenp.tpe eq origTree.thenp.tpe) && (newTree.elsep.tpe eq origTree.elsep.tpe))) newTree
589+
else ta.assignType(newTree, newTree.thenp, newTree.elsep)
590+
}
591+
592+
def propagateType(origTree: Match, newTree: Match)(implicit ctx: Context) = {
593+
if ((newTree eq origTree) || sameTypes(newTree.cases, origTree.cases)) newTree
594+
else ta.assignType(newTree, newTree.cases)
595+
}
596+
597+
def propagateType(origTree: CaseDef, newTree: CaseDef)(implicit ctx: Context) = {
598+
if ((newTree eq newTree) || (newTree.body.tpe eq origTree.body.tpe)) newTree
599+
else ta.assignType(newTree, newTree.body)
600+
}
601+
602+
def propagateType(origTree: Try, newTree: Try)(implicit ctx: Context) = {
603+
if ((newTree eq origTree) ||
604+
((newTree.expr.tpe eq origTree.expr.tpe) && (newTree.handler.tpe eq origTree.handler.tpe))) newTree
605+
else ta.assignType(newTree, newTree.expr, newTree.handler)
606+
}
607+
608+
def propagateType(origTree: SeqLiteral, newTree: SeqLiteral)(implicit ctx: Context) = {
609+
if ((newTree eq origTree) || sameTypes(newTree.elems, origTree.elems)) newTree
610+
else ta.assignType(newTree, newTree.elems)
611+
}
612+
613+
def propagateType(origTree: Annotated, newTree: Annotated)(implicit ctx: Context) = {
614+
if ((newTree eq origTree) || ((newTree.arg.tpe eq origTree.arg.tpe) && (newTree.annot eq origTree.annot))) newTree
615+
else ta.assignType(newTree, newTree.annot, newTree.arg)
616+
}
617+
618+
def propagateType(origTree: Select, newTree: Select)(implicit ctx: Context) = {
619+
if ((origTree eq newTree) || (origTree.qualifier.tpe eq newTree.qualifier.tpe)) newTree
620+
else newTree.tpe match {
621+
case tpe: NamedType => newTree.withType(tpe.derivedSelect(newTree.qualifier.tpe))
622+
case _ => newTree
623+
}
624+
}
504625
// convert a numeric with a toXXX method
505626
def primitiveConversion(tree: Tree, numericCls: Symbol)(implicit ctx: Context): Tree = {
506627
val mname = ("to" + numericCls.name).toTermName
@@ -515,6 +636,13 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
515636
}
516637
}
517638

639+
@tailrec
640+
def sameTypes(trees: List[tpd.Tree], trees1: List[tpd.Tree]): Boolean = {
641+
if (trees.isEmpty) trees.isEmpty
642+
else if (trees1.isEmpty) trees.isEmpty
643+
else (trees.head.tpe eq trees1.head.tpe) && sameTypes(trees.tail, trees1.tail)
644+
}
645+
518646
def evalOnce(tree: Tree)(within: Tree => Tree)(implicit ctx: Context) = {
519647
if (isIdempotentExpr(tree)) within(tree)
520648
else {
Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,18 @@
11
package dotty.tools.dotc
22
package transform
33

4-
import core._
5-
import Types._
6-
import Contexts._
7-
import Symbols._
8-
import Decorators._
9-
import StdNames.nme
10-
import NameOps._
11-
import language.implicitConversions
4+
import dotty.tools.dotc.core.Types._
5+
6+
import scala.language.implicitConversions
127

138
object TypeUtils {
149
implicit def decorateTypeUtils(tpe: Type): TypeUtils = new TypeUtils(tpe)
10+
1511
}
1612

1713
/** A decorator that provides methods for type transformations
1814
* that are needed in the transofmer pipeline (not needed right now)
1915
*/
2016
class TypeUtils(val self: Type) extends AnyVal {
21-
import TypeUtils._
2217

2318
}

0 commit comments

Comments
 (0)