Skip to content

Commit 12842bf

Browse files
committed
Add While and DoWhile extractors Tasty reflect
1 parent b369dcf commit 12842bf

File tree

4 files changed

+66
-55
lines changed

4 files changed

+66
-55
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/TastyImpl.scala

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -363,9 +363,30 @@ object TastyImpl extends scala.tasty.Tasty {
363363
}
364364

365365
object Block extends BlockExtractor {
366-
def unapply(x: Term)(implicit ctx: Context): Option[(List[Statement], Term)] = x match {
367-
case x: tpd.Block @unchecked => Some((x.stats, x.expr))
368-
case _ => None
366+
def unapply(x: Term)(implicit ctx: Context): Option[(List[Statement], Term)] = normalizedLoops(x) match {
367+
case Trees.Block(stats, expr) => Some((stats, expr))
368+
case _ => None
369+
}
370+
private def normalizedLoops(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = tree match {
371+
case block: tpd.Block =>
372+
if (block.stats.size <= 1) block
373+
else {
374+
def normalizeInnerLoops(stats: List[tpd.Tree]): List[tpd.Tree] = stats match {
375+
case (x: tpd.DefDef) :: y :: xs if y.symbol.is(Flags.Label) =>
376+
tpd.Block(x :: Nil, y) :: normalizeInnerLoops(xs)
377+
case x :: xs => x :: normalizeInnerLoops(xs)
378+
case Nil => Nil
379+
}
380+
if (block.expr.symbol.is(Flags.Label)) {
381+
val stats1 = normalizeInnerLoops(block.stats.init)
382+
val normalLoop = tpd.Block(block.stats.last :: Nil, block.expr)
383+
tpd.Block(stats1, normalLoop)
384+
} else {
385+
val stats1 = normalizeInnerLoops(block.stats)
386+
tpd.cpy.Block(block)(stats1, block.expr)
387+
}
388+
}
389+
case _ => tree
369390
}
370391
}
371392

@@ -430,6 +451,28 @@ object TastyImpl extends scala.tasty.Tasty {
430451
}
431452
}
432453

454+
object While extends WhileExtractor {
455+
def unapply(x: Term)(implicit ctx: Context): Option[(Term, Term)] = x match {
456+
case Trees.Block((ddef: tpd.DefDef) :: Nil, expr) if expr.symbol.is(Flags.Label) && expr.symbol.name == nme.WHILE_PREFIX =>
457+
val Trees.If(cond, Trees.Block(bodyStats, _), _) = ddef.rhs
458+
Some((cond, loopBody(bodyStats)))
459+
case _ => None
460+
}
461+
}
462+
463+
object DoWhile extends DoWhileExtractor {
464+
def unapply(x: Term)(implicit ctx: Context): Option[(Term, Term)] = x match {
465+
case Trees.Block((ddef: tpd.DefDef) :: Nil, expr) if expr.symbol.is(Flags.Label) && expr.symbol.name == nme.DO_WHILE_PREFIX =>
466+
val Trees.Block(bodyStats, Trees.If(cond, _, _)) = ddef.rhs
467+
Some((loopBody(bodyStats), cond))
468+
case _ => None
469+
}
470+
}
471+
472+
private def loopBody(stats: List[tpd.Tree])(implicit ctx: Context): tpd.Tree = stats match {
473+
case body :: Nil => body
474+
case stats => tpd.Block(stats.init, stats.last)
475+
}
433476
}
434477

435478
// ----- CaseDef --------------------------------------------------

library/src/scala/tasty/Tasty.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,17 @@ abstract class Tasty { tasty =>
302302
def unapply(x: Term)(implicit ctx: Context): Option[(Term, Int, Type)]
303303
}
304304

305+
val While: WhileExtractor
306+
abstract class WhileExtractor {
307+
/** Extractor for while loops. Matches `while (<cond>) <body>` and returns (<cond>, <body>) */
308+
def unapply(x: Term)(implicit ctx: Context): Option[(Term, Term)]
309+
}
310+
311+
val DoWhile: DoWhileExtractor
312+
abstract class DoWhileExtractor {
313+
/** Extractor for do while loops. Matches `do <body> while (<cond>)` and returns (<body>, <cond>) */
314+
def unapply(x: Term)(implicit ctx: Context): Option[(Term, Term)]
315+
}
305316
}
306317

307318
// ----- CaseDef --------------------------------------------------

library/src/scala/tasty/util/ShowSourceCode.scala

Lines changed: 7 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -149,35 +149,15 @@ class ShowSourceCode[T <: Tasty with Singleton](tasty0: T) extends Show[T](tasty
149149
this
150150
}
151151

152-
case While(cond, stats) =>
152+
case Term.While(cond, body) =>
153153
this += "while ("
154154
printTree(cond)
155155
this += ") "
156-
stats match {
157-
case stat :: Nil =>
158-
printTree(stat)
159-
case stats =>
160-
this += "{"
161-
indented {
162-
this += lineBreak()
163-
printTrees(stats, lineBreak())
164-
}
165-
this += lineBreak() += "}"
166-
}
156+
printTree(body)
167157

168-
case DoWhile(stats, cond) =>
158+
case Term.DoWhile(body, cond) =>
169159
this += "do "
170-
stats match {
171-
case stat :: Nil =>
172-
printTree(stat)
173-
case stats =>
174-
this += "{"
175-
indented {
176-
this += lineBreak()
177-
printTrees(stats, lineBreak())
178-
}
179-
this += lineBreak() += "}"
180-
}
160+
printTree(body)
181161
this += " while ("
182162
printTree(cond)
183163
this += ")"
@@ -262,14 +242,7 @@ class ShowSourceCode[T <: Tasty with Singleton](tasty0: T) extends Show[T](tasty
262242
this += " = "
263243
printTree(rhs)
264244

265-
case Term.Block(stats0, expr) =>
266-
def isLoopEntryPoint(tree: Tree): Boolean = tree match {
267-
case Term.Apply(Term.Ident("while$" | "doWhile$"), _) => true
268-
case _ => false
269-
}
270-
271-
val stats = stats0.filterNot(isLoopEntryPoint)
272-
245+
case Term.Block(stats, expr) =>
273246
expr match {
274247
case Term.Lambda(_, _) =>
275248
// Decompile lambda from { def annon$(...) = ...; closure(annon$, ...)}
@@ -286,10 +259,8 @@ class ShowSourceCode[T <: Tasty with Singleton](tasty0: T) extends Show[T](tasty
286259
this += lineBreak()
287260
printTrees(stats, lineBreak())
288261
}
289-
if (!isLoopEntryPoint(expr)) {
290-
this += lineBreak()
291-
printTree(expr)
292-
}
262+
this += lineBreak()
263+
printTree(expr)
293264
}
294265
this += lineBreak() += "}"
295266
}
@@ -765,22 +736,6 @@ class ShowSourceCode[T <: Tasty with Singleton](tasty0: T) extends Show[T](tasty
765736
}
766737
}
767738

768-
private object While {
769-
def unapply(arg: Tree)(implicit ctx: Context): Option[(Term, List[Statement])] = arg match {
770-
case DefDef("while$", _, _, _, Some(Term.If(cond, Term.Block(bodyStats, _), _))) => Some((cond, bodyStats))
771-
case Term.Block(List(tree), _) => unapply(tree)
772-
case _ => None
773-
}
774-
}
775-
776-
private object DoWhile {
777-
def unapply(arg: Tree)(implicit ctx: Context): Option[(List[Statement], Term)] = arg match {
778-
case DefDef("doWhile$", _, _, _, Some(Term.Block(body, Term.If(cond, _, _)))) => Some((body, cond))
779-
case Term.Block(List(tree), _) => unapply(tree)
780-
case _ => None
781-
}
782-
}
783-
784739
// TODO Provide some of these in scala.tasty.Tasty.scala and implement them using checks on symbols for performance
785740
private object Types {
786741

tests/pos/tasty/definitions.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ object definitions {
7373
case Return(expr: Term)
7474
case Repeated(args: List[Term])
7575
case SelectOuter(from: Term, levels: Int, target: Type) // can be generated by inlining
76+
case While(cond: Term, body: Term)
77+
case DoWhile(body: Term, cond: Term)
7678
}
7779

7880
/** Trees denoting types */

0 commit comments

Comments
 (0)