Skip to content

Commit 6ad8d72

Browse files
Merge pull request #3697 from dotty-staging/quote-constant-extraction
Quote constant extraction
2 parents 540371d + a6da04f commit 6ad8d72

File tree

10 files changed

+147
-45
lines changed

10 files changed

+147
-45
lines changed
Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
package dotty.tools.dotc.quoted
22

3-
import java.io.PrintStream
4-
3+
import dotty.tools.dotc.ast.tpd
4+
import dotty.tools.dotc.core.Contexts.Context
55
import dotty.tools.dotc.core.Phases.Phase
66

7-
/** Compiler that takes the contents of a quoted expression `expr` and produces outputs
8-
* the pretty printed code.
9-
*/
10-
class ExprDecompiler(out: PrintStream) extends ExprCompiler(null) {
7+
/** Compiler that takes the contents of a quoted expression `expr` and outputs it's tree. */
8+
class ExprDecompiler(output: tpd.Tree => Context => Unit) extends ExprCompiler(null) {
119
override def phases: List[List[Phase]] = List(
1210
List(new ExprFrontend(putInClass = false)), // Create class from Expr
13-
List(new QuotePrinter(out)) // Print all loaded classes
11+
List(new QuoteTreeOutput(output))
1412
)
13+
14+
class QuoteTreeOutput(output: tpd.Tree => Context => Unit) extends Phase {
15+
override def phaseName: String = "quotePrinter"
16+
override def run(implicit ctx: Context): Unit = output(ctx.compilationUnit.tpdTree)(ctx)
17+
}
1518
}

compiler/src/dotty/tools/dotc/quoted/QuoteDriver.scala

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
package dotty.tools.dotc.quoted
22

3+
import dotty.tools.dotc.ast.tpd
34
import dotty.tools.dotc.Driver
45
import dotty.tools.dotc.core.Contexts.Context
56
import dotty.tools.dotc.core.StdNames._
67
import dotty.tools.io.{AbstractFile, Directory, PlainDirectory, VirtualDirectory}
78
import dotty.tools.repl.AbstractFileClassLoader
9+
import dotty.tools.dotc.printing.DecompilerPrinter
810

911
import scala.quoted.Expr
10-
import java.io.ByteArrayOutputStream
11-
import java.io.PrintStream
12-
import java.nio.charset.StandardCharsets
1312

1413
class QuoteDriver extends Driver {
14+
import tpd._
1515

1616
def run[T](expr: Expr[T], settings: Runners.RunSettings): T = {
1717
val ctx: Context = initCtx.fresh
@@ -39,18 +39,24 @@ class QuoteDriver extends Driver {
3939
}
4040

4141
def show(expr: Expr[_]): String = {
42+
def show(tree: Tree, ctx: Context): String = {
43+
val printer = new DecompilerPrinter(ctx)
44+
val pageWidth = ctx.settings.pageWidth.value(ctx)
45+
printer.toText(tree).mkString(pageWidth, false)
46+
}
47+
withTree(expr, show)
48+
}
49+
50+
def withTree[T](expr: Expr[_], f: (Tree, Context) => T): T = {
4251
val ctx: Context = initCtx.fresh
4352
ctx.settings.color.update("never")(ctx) // TODO support colored show
44-
val baos = new ByteArrayOutputStream
45-
var ps: PrintStream = null
46-
try {
47-
ps = new PrintStream(baos, true, "utf-8")
48-
49-
new ExprDecompiler(ps).newRun(ctx).compileExpr(expr)
50-
51-
new String(baos.toByteArray, StandardCharsets.UTF_8)
53+
var output: Option[T] = None
54+
def registerTree(tree: tpd.Tree)(ctx: Context): Unit = {
55+
assert(output.isEmpty)
56+
output = Some(f(tree, ctx))
5257
}
53-
finally if (ps != null) ps.close()
58+
new ExprDecompiler(registerTree).newRun(ctx).compileExpr(expr)
59+
output.getOrElse(throw new Exception("Could not extact " + expr))
5460
}
5561

5662
override def initCtx: Context = {

compiler/src/dotty/tools/dotc/quoted/QuotePrinter.scala

Lines changed: 0 additions & 20 deletions
This file was deleted.

compiler/src/dotty/tools/dotc/quoted/Runners.scala

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package dotty.tools.dotc.quoted
22

3-
import dotty.tools.dotc.ast.Trees.Literal
4-
import dotty.tools.dotc.core.Constants.Constant
3+
import dotty.tools.dotc.ast.Trees._
4+
import dotty.tools.dotc.ast.tpd
5+
import dotty.tools.dotc.core.Constants._
6+
import dotty.tools.dotc.core.Contexts._
57
import dotty.tools.dotc.printing.RefinedPrinter
68

79
import scala.quoted.Expr
@@ -10,19 +12,34 @@ import scala.runtime.quoted._
1012

1113
/** Default runners for quoted expressions */
1214
object Runners {
15+
import tpd._
1316

1417
implicit def runner[T]: Runner[T] = new Runner[T] {
1518

1619
def run(expr: Expr[T]): T = Runners.run(expr, RunSettings())
1720

1821
def show(expr: Expr[T]): String = expr match {
1922
case expr: ConstantExpr[T] =>
20-
val ctx = new QuoteDriver().initCtx
21-
ctx.settings.color.update("never")(ctx)
23+
implicit val ctx = new QuoteDriver().initCtx
24+
ctx.settings.color.update("never")
2225
val printer = new RefinedPrinter(ctx)
2326
printer.toText(Literal(Constant(expr.value))).mkString(Int.MaxValue, false)
2427
case _ => new QuoteDriver().show(expr)
2528
}
29+
30+
def toConstantOpt(expr: Expr[T]): Option[T] = {
31+
def toConstantOpt(tree: Tree): Option[T] = tree match {
32+
case Literal(Constant(c)) => Some(c.asInstanceOf[T])
33+
case Block(Nil, e) => toConstantOpt(e)
34+
case Inlined(_, Nil, e) => toConstantOpt(e)
35+
case _ => None
36+
}
37+
expr match {
38+
case expr: ConstantExpr[T] => Some(expr.value)
39+
case _ => new QuoteDriver().withTree(expr, (tree, _) => toConstantOpt(tree))
40+
}
41+
}
42+
2643
}
2744

2845
def run[T](expr: Expr[T], settings: RunSettings): T = expr match {
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package scala.quoted
2+
3+
import scala.runtime.quoted.Runner
4+
5+
object Constant {
6+
def unapply[T](expr: Expr[T])(implicit runner: Runner[T]): Option[T] = runner.toConstantOpt(expr)
7+
}

library/src/scala/quoted/TastyExpr.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@ package scala.quoted
33
import scala.runtime.quoted.Unpickler.Pickled
44

55
/** An Expr backed by a pickled TASTY tree */
6-
final case class TastyExpr[T](tasty: Pickled, args: Seq[Any]) extends Expr[T] with TastyQuoted {
6+
final class TastyExpr[T](val tasty: Pickled, val args: Seq[Any]) extends Expr[T] with TastyQuoted {
77
override def toString(): String = s"Expr(<pickled>)"
88
}

library/src/scala/quoted/TastyType.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@ package scala.quoted
33
import scala.runtime.quoted.Unpickler.Pickled
44

55
/** A Type backed by a pickled TASTY tree */
6-
final case class TastyType[T](tasty: Pickled, args: Seq[Any]) extends Type[T] with TastyQuoted {
6+
final class TastyType[T](val tasty: Pickled, val args: Seq[Any]) extends Type[T] with TastyQuoted {
77
override def toString(): String = s"Type(<pickled>)"
88
}

library/src/scala/runtime/quoted/Runner.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ import scala.quoted.Expr
77
trait Runner[T] {
88
def run(expr: Expr[T]): T
99
def show(expr: Expr[T]): String
10+
def toConstantOpt(expr: Expr[T]): Option[T]
1011
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
3
2+
4
3+
abc
4+
null
5+
OK
6+
{
7+
{
8+
val y: Double = 3.0.*(3.0)
9+
y
10+
}
11+
}
12+
9.0
13+
{
14+
{
15+
val y: Double = 4.0.*(4.0)
16+
y
17+
}
18+
}
19+
16.0
20+
{
21+
{
22+
val y: Double = 5.0.*(5.0)
23+
y
24+
}
25+
}
26+
25.0
27+
{
28+
Test.dynamicPower(
29+
{
30+
println("foo")
31+
2
32+
}
33+
, 6.0)
34+
}
35+
foo
36+
36.0
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import scala.quoted._
2+
3+
import dotty.tools.dotc.quoted.Runners._
4+
5+
object Test {
6+
7+
def main(args: Array[String]): Unit = {
8+
(3: Expr[Int]) match { case Constant(n) => println(n) }
9+
'(4) match { case Constant(n) => println(n) }
10+
'("abc") match { case Constant(n) => println(n) }
11+
'(null) match { case Constant(n) => println(n) }
12+
13+
'(new Object) match { case Constant(n) => println(n); case _ => println("OK") }
14+
15+
16+
// 2 is a lifted constant
17+
println(power(2, 3.0).show)
18+
println(power(2, 3.0).run)
19+
20+
// n is a lifted constant
21+
val n = 2
22+
println(power(n, 4.0).show)
23+
println(power(n, 4.0).run)
24+
25+
// n is a constant in a quote
26+
println(power('(2), 5.0).show)
27+
println(power('(2), 5.0).run)
28+
29+
// n2 is clearly not a constant
30+
val n2 = '{ println("foo"); 2 }
31+
println(power(n2, 6.0).show)
32+
println(power(n2, 6.0).run)
33+
}
34+
35+
def power(n: Expr[Int], x: Expr[Double]): Expr[Double] = {
36+
n match {
37+
case Constant(n1) => powerCode(n1, x)
38+
case _ => '{ dynamicPower(~n, ~x) }
39+
}
40+
}
41+
42+
private def powerCode(n: Int, x: Expr[Double]): Expr[Double] =
43+
if (n == 0) '(1.0)
44+
else if (n == 1) x
45+
else if (n % 2 == 0) '{ { val y = ~x * ~x; ~powerCode(n / 2, '(y)) } }
46+
else '{ ~x * ~powerCode(n - 1, x) }
47+
48+
def dynamicPower(n: Int, x: Double): Double =
49+
if (n == 0) 1.0
50+
else if (n % 2 == 0) dynamicPower(n / 2, x * x)
51+
else x * dynamicPower(n - 1, x)
52+
}

0 commit comments

Comments
 (0)