Skip to content

Commit 8cccaff

Browse files
committed
Fix #3876: Implement Expr.AsFunction
1 parent 577781c commit 8cccaff

File tree

4 files changed

+103
-2
lines changed

4 files changed

+103
-2
lines changed

compiler/src/dotty/tools/dotc/core/quoted/PickledQuotes.scala

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
package dotty.tools.dotc.core.quoted
22

33
import dotty.tools.dotc.ast.Trees._
4-
import dotty.tools.dotc.ast.{tpd, untpd}
4+
import dotty.tools.dotc.ast.tpd
55
import dotty.tools.dotc.config.Printers._
66
import dotty.tools.dotc.core.Constants.Constant
77
import dotty.tools.dotc.core.Contexts._
88
import dotty.tools.dotc.core.Decorators._
99
import dotty.tools.dotc.core.Flags._
10+
import dotty.tools.dotc.core.NameKinds
1011
import dotty.tools.dotc.core.StdNames._
1112
import dotty.tools.dotc.core.Symbols._
1213
import dotty.tools.dotc.core.tasty.{TastyPickler, TastyPrinter, TastyString}
@@ -35,6 +36,8 @@ object PickledQuotes {
3536
def quotedToTree(expr: quoted.Quoted)(implicit ctx: Context): Tree = expr match {
3637
case expr: quoted.TastyQuoted => unpickleQuote(expr)
3738
case expr: quoted.Liftable.ConstantExpr[_] => Literal(Constant(expr.value))
39+
case expr: quoted.Expr.FunctionAppliedTo[_, _] =>
40+
functionAppliedTo(quotedToTree(expr.f), quotedToTree(expr.x))
3841
case expr: quoted.Type.TaggedPrimitive[_] =>
3942
val tpe = expr.ct match {
4043
case ClassTag.Unit => defn.UnitType
@@ -111,4 +114,28 @@ object PickledQuotes {
111114
}
112115
tree
113116
}
117+
118+
private def functionAppliedTo(f: Tree, x: Tree)(implicit ctx: Context): Tree = {
119+
val x1 = SyntheticValDef(NameKinds.UniqueName.fresh("x".toTermName), x)
120+
def x1Ref() = ref(x1.symbol)
121+
def rec(f: Tree): Tree = f match {
122+
case Block((ddef: DefDef) :: Nil, _: Closure) =>
123+
new TreeMap() {
124+
private val paramSym = ddef.vparamss.head.head.symbol
125+
override def transform(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = tree match {
126+
case tree: Ident if tree.symbol == paramSym => x1Ref().withPos(tree.pos)
127+
case _ => super.transform(tree)
128+
}
129+
}.transform(ddef.rhs)
130+
case Block(stats, expr) =>
131+
val applied = rec(expr)
132+
if (stats.isEmpty) applied
133+
else Block(stats, applied)
134+
case Inlined(call, bindings, expansion) =>
135+
Inlined(call, bindings, rec(expansion))
136+
case _ =>
137+
f.select(nme.apply).appliedTo(x1Ref())
138+
}
139+
Block(x1 :: Nil, rec(f))
140+
}
114141
}

library/src/scala/quoted/Expr.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ object Expr {
1313
ev.toExpr(x)
1414

1515
implicit class AsFunction[T, U](private val f: Expr[T => U]) extends AnyVal {
16-
def apply(x: Expr[T]): Expr[U] = ???
16+
def apply(x: Expr[T]): Expr[U] = new FunctionAppliedTo[T, U](f, x)
1717
}
18+
19+
final class FunctionAppliedTo[T, U] private[Expr](val f: Expr[T => U], val x: Expr[T]) extends Expr[U]
1820
}

tests/run-with-compiler/i3876.check

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
6
2+
{
3+
val x$1: Int = 3
4+
{
5+
x$1.+(x$1)
6+
}
7+
}
8+
6
9+
{
10+
val x$1: Int = 3
11+
{
12+
def f(x: Int): Int = x.+(x)
13+
f(x$1)
14+
}
15+
}
16+
6
17+
{
18+
val x$1: Int = 3
19+
{
20+
val f:
21+
Function1[Int, Int]
22+
{
23+
def apply(x: Int): Int
24+
}
25+
=
26+
{
27+
(x: Int) => x.+(x)
28+
}
29+
(f: (x: Int) => Int).apply(x$1)
30+
}
31+
}
32+
6
33+
{
34+
val x$1: Int = 3
35+
/* inlined from Test*/
36+
{
37+
x$1.+(x$1)
38+
}
39+
}

tests/run-with-compiler/i3876.scala

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import dotty.tools.dotc.quoted.Runners._
2+
import scala.quoted._
3+
object Test {
4+
def main(args: Array[String]): Unit = {
5+
val x: Expr[Int] = '(3)
6+
7+
val f: Expr[Int => Int] = '{ (x: Int) => x + x }
8+
println(f(x).run)
9+
println(f(x).show)
10+
11+
val f2: Expr[Int => Int] = '{
12+
def f(x: Int): Int = x + x
13+
f
14+
}
15+
println(f2(x).run)
16+
println(f2(x).show)
17+
18+
val f3: Expr[Int => Int] = '{
19+
val f: (x: Int) => Int = x => x + x
20+
f
21+
}
22+
println(f3(x).run)
23+
println(f3(x).show) // TODO improve printer
24+
25+
val f4: Expr[Int => Int] = '{
26+
inlineLambda
27+
}
28+
println(f4(x).run)
29+
println(f4(x).show)
30+
}
31+
32+
inline def inlineLambda: Int => Int = x => x + x
33+
}

0 commit comments

Comments
 (0)