diff --git a/compiler/src/dotty/tools/dotc/ast/Trees.scala b/compiler/src/dotty/tools/dotc/ast/Trees.scala index f65154984f50..34d3a0bc1ca8 100644 --- a/compiler/src/dotty/tools/dotc/ast/Trees.scala +++ b/compiler/src/dotty/tools/dotc/ast/Trees.scala @@ -1638,6 +1638,7 @@ object Trees { abstract class TreeTraverser extends TreeAccumulator[Unit] { def traverse(tree: Tree)(using Context): Unit + def traverse(trees: List[Tree])(using Context) = apply((), trees) def apply(x: Unit, tree: Tree)(using Context): Unit = traverse(tree) protected def traverseChildren(tree: Tree)(using Context): Unit = foldOver((), tree) } diff --git a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala index bd3f4f44984b..b89af15488f2 100644 --- a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala +++ b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala @@ -130,6 +130,7 @@ class TreeChecker extends Phase with SymTransformer { assert(ctx.typerState.constraint.domainLambdas.isEmpty, i"non-empty constraint at end of $fusedPhase: ${ctx.typerState.constraint}, ownedVars = ${ctx.typerState.ownedVars.toList}%, %") assertSelectWrapsNew(ctx.compilationUnit.tpdTree) + TreeNodeChecker.traverse(ctx.compilationUnit.tpdTree) } val checkingCtx = ctx @@ -646,4 +647,25 @@ object TreeChecker { tp } }.apply(tp0) + + /** Run some additional checks on the nodes of the trees. Specifically: + * + * - TypeTree can only appear in TypeApply args, New, Typed tpt, Closure + * tpt, SeqLiteral elemtpt, ValDef tpt, DefDef tpt, and TypeDef rhs. + */ + object TreeNodeChecker extends untpd.TreeTraverser: + import untpd._ + def traverse(tree: Tree)(using Context) = tree match + case t: TypeTree => assert(assertion = false, i"TypeTree not expected: $t") + case t @ TypeApply(fun, _targs) => traverse(fun) + case t @ New(_tpt) => + case t @ Typed(expr, _tpt) => traverse(expr) + case t @ Closure(env, meth, _tpt) => traverse(env); traverse(meth) + case t @ SeqLiteral(elems, _elemtpt) => traverse(elems) + case t @ ValDef(_, _tpt, _) => traverse(t.rhs) + case t @ DefDef(_, paramss, _tpt, _) => for params <- paramss do traverse(params); traverse(t.rhs) + case t @ TypeDef(_, _rhs) => + case t @ Template(constr, parents, self, _) => traverse(constr); traverse(parents); traverse(self); traverse(t.body) + case t => traverseChildren(t) + end traverse }