Skip to content

Commit bcd4754

Browse files
committed
Add checkTypes to support type unit testing
Implementation mostly stolen from TypeStealer.
1 parent 97c63aa commit bcd4754

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package dotty.tools
2+
3+
import org.junit.Test
4+
import org.junit.Assert.{ assertFalse, assertTrue }
5+
6+
import dotc.ast.Trees._
7+
import dotc.core.Decorators._
8+
9+
class CheckTypeTest extends DottyTest {
10+
@Test
11+
def checkTypeTest: Unit = {
12+
val source = """
13+
|class A
14+
|class B extends A
15+
""".stripMargin
16+
17+
val types = List(
18+
"A",
19+
"B",
20+
"List[_]",
21+
"List[Int]",
22+
"List[AnyRef]",
23+
"List[String]",
24+
"List[A]",
25+
"List[B]"
26+
)
27+
28+
checkTypes(source, types: _*) {
29+
case (List(a, b, lu, li, lr, ls, la, lb), context) =>
30+
implicit val ctx = context
31+
32+
assertTrue ( b <:< a)
33+
assertTrue (li <:< lu)
34+
assertFalse (li <:< lr)
35+
assertTrue (ls <:< lr)
36+
assertTrue (lb <:< la)
37+
assertFalse (la <:< lb)
38+
}
39+
}
40+
}

compiler/test/dotty/tools/DottyTest.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,25 @@ trait DottyTest extends ContextEscapeDetection {
7575
run.runContext
7676
}
7777

78+
def checkTypes(source: String, typeStrings: String*)(assertion: (List[Type], Context) => Unit) = {
79+
val dummyName = "x_x_x"
80+
val vals = typeStrings.zipWithIndex.map{case (s, x)=> s"val ${dummyName}$x: $s = ???"}.mkString("\n")
81+
val gatheredSource = s" ${source}\n object A$dummyName {$vals}"
82+
checkCompile("frontend", gatheredSource) {
83+
(tree, context) =>
84+
implicit val ctx = context
85+
val findValDef: (List[tpd.ValDef], tpd.Tree) => List[tpd.ValDef] =
86+
(acc , tree) => { tree match {
87+
case t: tpd.ValDef if t.name.startsWith(dummyName) => t :: acc
88+
case _ => acc
89+
}
90+
}
91+
val d = new tpd.DeepFolder[List[tpd.ValDef]](findValDef).foldOver(Nil, tree)
92+
val tpes = d.map(_.tpe.widen).reverse
93+
assertion(tpes, context)
94+
}
95+
}
96+
7897
def methType(names: String*)(paramTypes: Type*)(resultType: Type = defn.UnitType) =
7998
MethodType(names.toList map (_.toTermName), paramTypes.toList, resultType)
8099
}

0 commit comments

Comments
 (0)