From 5ac8990d6160c450031a5350c8e17b4cf64ce693 Mon Sep 17 00:00:00 2001 From: poechsel Date: Tue, 20 Nov 2018 12:23:36 +0100 Subject: [PATCH 1/2] Add missing cases to treeUtils (TypeLambdaTree, Bind, Block, MatchType) --- library/src/scala/tasty/reflect/TreeUtils.scala | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/library/src/scala/tasty/reflect/TreeUtils.scala b/library/src/scala/tasty/reflect/TreeUtils.scala index 74709fa4e356..ad5fc933a79d 100644 --- a/library/src/scala/tasty/reflect/TreeUtils.scala +++ b/library/src/scala/tasty/reflect/TreeUtils.scala @@ -16,11 +16,13 @@ trait TreeUtils def foldTree(x: X, tree: Tree)(implicit ctx: Context): X def foldTypeTree(x: X, tree: TypeOrBoundsTree)(implicit ctx: Context): X def foldCaseDef(x: X, tree: CaseDef)(implicit ctx: Context): X + def foldTypeCaseDef(x: X, tree: TypeCaseDef)(implicit ctx: Context): X def foldPattern(x: X, tree: Pattern)(implicit ctx: Context): X def foldTrees(x: X, trees: Iterable[Tree])(implicit ctx: Context): X = (x /: trees)(foldTree) def foldTypeTrees(x: X, trees: Iterable[TypeOrBoundsTree])(implicit ctx: Context): X = (x /: trees)(foldTypeTree) def foldCaseDefs(x: X, trees: Iterable[CaseDef])(implicit ctx: Context): X = (x /: trees)(foldCaseDef) + def foldTypeCaseDefs(x: X, trees: Iterable[TypeCaseDef])(implicit ctx: Context): X = (x /: trees)(foldTypeCaseDef) def foldPatterns(x: X, trees: Iterable[Pattern])(implicit ctx: Context): X = (x /: trees)(foldPattern) private def foldParents(x: X, trees: Iterable[TermOrTypeTree])(implicit ctx: Context): X = (x /: trees)(foldOverTermOrTypeTree) @@ -97,6 +99,13 @@ trait TreeUtils case TypeTree.Applied(tpt, args) => foldTypeTrees(foldTypeTree(x, tpt), args) case TypeTree.ByName(result) => foldTypeTree(x, result) case TypeTree.Annotated(arg, annot) => foldTree(foldTypeTree(x, arg), annot) + case TypeTree.TypeLambdaTree(typedefs, arg) => foldTrees(foldTypeTree(x, arg), typedefs) + case TypeTree.Bind(_, tbt) => foldTypeTree(x, tbt) + case TypeTree.Block(typedefs, tpt) => foldTrees(foldTypeTree(x, tpt), typedefs) + case TypeTree.MatchType(boundopt, selector, cases) => { + val bound_fold_result = boundopt.map(foldTypeTree(x, _)).getOrElse(x) + foldTypeCaseDefs(foldTypeTree(bound_fold_result, selector), cases) + } case TypeBoundsTree(lo, hi) => foldTypeTree(foldTypeTree(x, lo), hi) } @@ -104,6 +113,11 @@ trait TreeUtils case CaseDef(pat, guard, body) => foldTree(foldTrees(foldPattern(x, pat), guard), body) } + def foldOverTypeCaseDef(x: X, tree: TypeCaseDef)(implicit ctx: Context): X = tree match { + case TypeCaseDef(pat, body) => foldTypeTree(foldTypeTree(x, pat), body) + } + + def foldOverPattern(x: X, tree: Pattern)(implicit ctx: Context): X = tree match { case Pattern.Value(v) => foldTree(x, v) case Pattern.Bind(_, body) => foldPattern(x, body) @@ -124,16 +138,19 @@ trait TreeUtils def traverseTree(tree: Tree)(implicit ctx: Context): Unit = traverseTreeChildren(tree) def traverseTypeTree(tree: TypeOrBoundsTree)(implicit ctx: Context): Unit = traverseTypeTreeChildren(tree) def traverseCaseDef(tree: CaseDef)(implicit ctx: Context): Unit = traverseCaseDefChildren(tree) + def traverseTypeCaseDef(tree: TypeCaseDef)(implicit ctx: Context): Unit = traverseTypeCaseDefChildren(tree) def traversePattern(tree: Pattern)(implicit ctx: Context): Unit = traversePatternChildren(tree) def foldTree(x: Unit, tree: Tree)(implicit ctx: Context): Unit = traverseTree(tree) def foldTypeTree(x: Unit, tree: TypeOrBoundsTree)(implicit ctx: Context) = traverseTypeTree(tree) def foldCaseDef(x: Unit, tree: CaseDef)(implicit ctx: Context) = traverseCaseDef(tree) + def foldTypeCaseDef(x: Unit, tree: TypeCaseDef)(implicit ctx: Context) = traverseTypeCaseDef(tree) def foldPattern(x: Unit, tree: Pattern)(implicit ctx: Context) = traversePattern(tree) protected def traverseTreeChildren(tree: Tree)(implicit ctx: Context): Unit = foldOverTree((), tree) protected def traverseTypeTreeChildren(tree: TypeOrBoundsTree)(implicit ctx: Context): Unit = foldOverTypeTree((), tree) protected def traverseCaseDefChildren(tree: CaseDef)(implicit ctx: Context): Unit = foldOverCaseDef((), tree) + protected def traverseTypeCaseDefChildren(tree: TypeCaseDef)(implicit ctx: Context): Unit = foldOverTypeCaseDef((), tree) protected def traversePatternChildren(tree: Pattern)(implicit ctx: Context): Unit = foldOverPattern((), tree) } From 846196ac08029ec2297edb5fa1d7ceb4e7b98a8d Mon Sep 17 00:00:00 2001 From: poechsel Date: Tue, 20 Nov 2018 15:00:06 +0100 Subject: [PATCH 2/2] fix call order for TypeLambdaTree and Block, use of fold for MatchType --- library/src/scala/tasty/reflect/TreeUtils.scala | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/library/src/scala/tasty/reflect/TreeUtils.scala b/library/src/scala/tasty/reflect/TreeUtils.scala index ad5fc933a79d..7f8c72b8fab4 100644 --- a/library/src/scala/tasty/reflect/TreeUtils.scala +++ b/library/src/scala/tasty/reflect/TreeUtils.scala @@ -99,13 +99,11 @@ trait TreeUtils case TypeTree.Applied(tpt, args) => foldTypeTrees(foldTypeTree(x, tpt), args) case TypeTree.ByName(result) => foldTypeTree(x, result) case TypeTree.Annotated(arg, annot) => foldTree(foldTypeTree(x, arg), annot) - case TypeTree.TypeLambdaTree(typedefs, arg) => foldTrees(foldTypeTree(x, arg), typedefs) + case TypeTree.TypeLambdaTree(typedefs, arg) => foldTypeTree(foldTrees(x, typedefs), arg) case TypeTree.Bind(_, tbt) => foldTypeTree(x, tbt) - case TypeTree.Block(typedefs, tpt) => foldTrees(foldTypeTree(x, tpt), typedefs) - case TypeTree.MatchType(boundopt, selector, cases) => { - val bound_fold_result = boundopt.map(foldTypeTree(x, _)).getOrElse(x) - foldTypeCaseDefs(foldTypeTree(bound_fold_result, selector), cases) - } + case TypeTree.Block(typedefs, tpt) => foldTypeTree(foldTrees(x, typedefs), tpt) + case TypeTree.MatchType(boundopt, selector, cases) => + foldTypeCaseDefs(foldTypeTree(boundopt.fold(x)(foldTypeTree(x, _)), selector), cases) case TypeBoundsTree(lo, hi) => foldTypeTree(foldTypeTree(x, lo), hi) } @@ -117,7 +115,6 @@ trait TreeUtils case TypeCaseDef(pat, body) => foldTypeTree(foldTypeTree(x, pat), body) } - def foldOverPattern(x: X, tree: Pattern)(implicit ctx: Context): X = tree match { case Pattern.Value(v) => foldTree(x, v) case Pattern.Bind(_, body) => foldPattern(x, body)