Skip to content

Commit 948c6ea

Browse files
committed
Add While and DoWhile extractors Tasty reflect
1 parent 67b1f03 commit 948c6ea

File tree

4 files changed

+66
-56
lines changed

4 files changed

+66
-56
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/TastyImpl.scala

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -370,9 +370,30 @@ object TastyImpl extends scala.tasty.Tasty {
370370
}
371371

372372
object Block extends BlockExtractor {
373-
def unapply(x: Term)(implicit ctx: Context): Option[(List[Statement], Term)] = x match {
374-
case x: tpd.Block @unchecked => Some((x.stats, x.expr))
375-
case _ => None
373+
def unapply(x: Term)(implicit ctx: Context): Option[(List[Statement], Term)] = normalizedLoops(x) match {
374+
case Trees.Block(stats, expr) => Some((stats, expr))
375+
case _ => None
376+
}
377+
private def normalizedLoops(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = tree match {
378+
case block: tpd.Block =>
379+
if (block.stats.size <= 1) block
380+
else {
381+
def normalizeInnerLoops(stats: List[tpd.Tree]): List[tpd.Tree] = stats match {
382+
case (x: tpd.DefDef) :: y :: xs if y.symbol.is(Flags.Label) =>
383+
tpd.Block(x :: Nil, y) :: normalizeInnerLoops(xs)
384+
case x :: xs => x :: normalizeInnerLoops(xs)
385+
case Nil => Nil
386+
}
387+
if (block.expr.symbol.is(Flags.Label)) {
388+
val stats1 = normalizeInnerLoops(block.stats.init)
389+
val normalLoop = tpd.Block(block.stats.last :: Nil, block.expr)
390+
tpd.Block(stats1, normalLoop)
391+
} else {
392+
val stats1 = normalizeInnerLoops(block.stats)
393+
tpd.cpy.Block(block)(stats1, block.expr)
394+
}
395+
}
396+
case _ => tree
376397
}
377398
}
378399

@@ -437,6 +458,28 @@ object TastyImpl extends scala.tasty.Tasty {
437458
}
438459
}
439460

461+
object While extends WhileExtractor {
462+
def unapply(x: Term)(implicit ctx: Context): Option[(Term, Term)] = x match {
463+
case Trees.Block((ddef: tpd.DefDef) :: Nil, expr) if expr.symbol.is(Flags.Label) && expr.symbol.name == nme.WHILE_PREFIX =>
464+
val Trees.If(cond, Trees.Block(bodyStats, _), _) = ddef.rhs
465+
Some((cond, loopBody(bodyStats)))
466+
case _ => None
467+
}
468+
}
469+
470+
object DoWhile extends DoWhileExtractor {
471+
def unapply(x: Term)(implicit ctx: Context): Option[(Term, Term)] = x match {
472+
case Trees.Block((ddef: tpd.DefDef) :: Nil, expr) if expr.symbol.is(Flags.Label) && expr.symbol.name == nme.DO_WHILE_PREFIX =>
473+
val Trees.Block(bodyStats, Trees.If(cond, _, _)) = ddef.rhs
474+
Some((loopBody(bodyStats), cond))
475+
case _ => None
476+
}
477+
}
478+
479+
private def loopBody(stats: List[tpd.Tree])(implicit ctx: Context): tpd.Tree = stats match {
480+
case body :: Nil => body
481+
case stats => tpd.Block(stats.init, stats.last)
482+
}
440483
}
441484

442485
// ----- CaseDef --------------------------------------------------

library/src/scala/tasty/Tasty.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,17 @@ abstract class Tasty { tasty =>
304304
def unapply(x: Term)(implicit ctx: Context): Option[(Term, Int, Type)]
305305
}
306306

307+
val While: WhileExtractor
308+
abstract class WhileExtractor {
309+
/** Extractor for while loops. Matches `while (<cond>) <body>` and returns (<cond>, <body>) */
310+
def unapply(x: Term)(implicit ctx: Context): Option[(Term, Term)]
311+
}
312+
313+
val DoWhile: DoWhileExtractor
314+
abstract class DoWhileExtractor {
315+
/** Extractor for do while loops. Matches `do <body> while (<cond>)` and returns (<body>, <cond>) */
316+
def unapply(x: Term)(implicit ctx: Context): Option[(Term, Term)]
317+
}
307318
}
308319

309320
// ----- CaseDef --------------------------------------------------

library/src/scala/tasty/util/ShowSourceCode.scala

Lines changed: 7 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -176,35 +176,15 @@ class ShowSourceCode[T <: Tasty with Singleton](tasty0: T) extends Show[T](tasty
176176
this
177177
}
178178

179-
case While(cond, stats) =>
179+
case Term.While(cond, body) =>
180180
this += "while ("
181181
printTree(cond)
182182
this += ") "
183-
stats match {
184-
case stat :: Nil =>
185-
printTree(stat)
186-
case stats =>
187-
this += "{"
188-
indented {
189-
this += lineBreak()
190-
printTrees(stats, lineBreak())
191-
}
192-
this += lineBreak() += "}"
193-
}
183+
printTree(body)
194184

195-
case DoWhile(stats, cond) =>
185+
case Term.DoWhile(body, cond) =>
196186
this += "do "
197-
stats match {
198-
case stat :: Nil =>
199-
printTree(stat)
200-
case stats =>
201-
this += "{"
202-
indented {
203-
this += lineBreak()
204-
printTrees(stats, lineBreak())
205-
}
206-
this += lineBreak() += "}"
207-
}
187+
printTree(body)
208188
this += " while ("
209189
printTree(cond)
210190
this += ")"
@@ -310,14 +290,7 @@ class ShowSourceCode[T <: Tasty with Singleton](tasty0: T) extends Show[T](tasty
310290
this += " = "
311291
printTree(rhs)
312292

313-
case Term.Block(stats0, expr) =>
314-
def isLoopEntryPoint(tree: Tree): Boolean = tree match {
315-
case Term.Apply(Term.Ident("while$" | "doWhile$"), _) => true
316-
case _ => false
317-
}
318-
319-
val stats = stats0.filterNot(isLoopEntryPoint)
320-
293+
case Term.Block(stats, expr) =>
321294
expr match {
322295
case Term.Lambda(_, _) =>
323296
// Decompile lambda from { def annon$(...) = ...; closure(annon$, ...)}
@@ -334,10 +307,8 @@ class ShowSourceCode[T <: Tasty with Singleton](tasty0: T) extends Show[T](tasty
334307
this += lineBreak()
335308
printTrees(stats, lineBreak())
336309
}
337-
if (!isLoopEntryPoint(expr)) {
338-
this += lineBreak()
339-
printTree(expr)
340-
}
310+
this += lineBreak()
311+
printTree(expr)
341312
}
342313
this += lineBreak() += "}"
343314
}
@@ -938,7 +909,6 @@ class ShowSourceCode[T <: Tasty with Singleton](tasty0: T) extends Show[T](tasty
938909
private def escapedString(str: String): String = str flatMap escapedChar
939910
}
940911

941-
942912
private object SpecialOp {
943913
def unapply(arg: Term)(implicit ctx: Context): Option[(String, List[Term])] = arg match {
944914
case arg@Term.Apply(fn, args) =>
@@ -951,22 +921,6 @@ class ShowSourceCode[T <: Tasty with Singleton](tasty0: T) extends Show[T](tasty
951921
}
952922
}
953923

954-
private object While {
955-
def unapply(arg: Tree)(implicit ctx: Context): Option[(Term, List[Statement])] = arg match {
956-
case DefDef("while$", _, _, _, Some(Term.If(cond, Term.Block(bodyStats, _), _))) => Some((cond, bodyStats))
957-
case Term.Block(List(tree), _) => unapply(tree)
958-
case _ => None
959-
}
960-
}
961-
962-
private object DoWhile {
963-
def unapply(arg: Tree)(implicit ctx: Context): Option[(List[Statement], Term)] = arg match {
964-
case DefDef("doWhile$", _, _, _, Some(Term.Block(body, Term.If(cond, _, _)))) => Some((body, cond))
965-
case Term.Block(List(tree), _) => unapply(tree)
966-
case _ => None
967-
}
968-
}
969-
970924
private object Annotation {
971925
def unapply(arg: Tree)(implicit ctx: Context): Option[(TypeTree, List[Term])] = arg match {
972926
case Term.Apply(Term.Select(Term.New(annot), "<init>", _), args) => Some((annot, args))

tests/pos/tasty/definitions.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ object definitions {
8585
case Return(expr: Term)
8686
case Repeated(args: List[Term])
8787
case SelectOuter(from: Term, levels: Int, target: Type) // can be generated by inlining
88+
case While(cond: Term, body: Term)
89+
case DoWhile(body: Term, cond: Term)
8890
}
8991

9092
/** Trees denoting types */

0 commit comments

Comments
 (0)