Skip to content

Commit 055132d

Browse files
committed
Implement quoted Lambda extractor
This extractor returns the lifted representation of the body of an explicit lambda
1 parent 3937029 commit 055132d

File tree

10 files changed

+111
-39
lines changed

10 files changed

+111
-39
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala

Lines changed: 34 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2051,44 +2051,41 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
20512051
}}
20522052
val argVals = argVals0.reverse
20532053
val argRefs = argRefs0.reverse
2054-
def rec(fn: Tree, topAscription: Option[TypeTree]): Tree = fn match {
2055-
case Typed(expr, tpt) =>
2056-
// we need to retain any type ascriptions we see and:
2057-
// a) if we succeed, ascribe the result type of the ascription to the inlined body
2058-
// b) if we fail, re-ascribe the same type to whatever it was we couldn't inline
2059-
// note: if you see many nested ascriptions, keep the top one as that's what the enclosing expression expects
2060-
rec(expr, topAscription.orElse(Some(tpt)))
2061-
case Inlined(call, bindings, expansion) =>
2062-
// this case must go before closureDef to avoid dropping the inline node
2063-
cpy.Inlined(fn)(call, bindings, rec(expansion, topAscription))
2064-
case closureDef(ddef) =>
2065-
val paramSyms = ddef.vparamss.head.map(param => param.symbol)
2066-
val paramToVals = paramSyms.zip(argRefs).toMap
2067-
val result = new TreeTypeMap(
2068-
oldOwners = ddef.symbol :: Nil,
2069-
newOwners = ctx.owner :: Nil,
2070-
treeMap = tree => paramToVals.get(tree.symbol).map(_.withSpan(tree.span)).getOrElse(tree)
2071-
).transform(ddef.rhs)
2072-
topAscription match {
2073-
case Some(tpt) =>
2074-
// we assume the ascribed type has an apply that has a MethodType with a single param list (there should be no polys)
2075-
val methodType = tpt.tpe.member(nme.apply).info.asInstanceOf[MethodType]
2076-
// result might contain paramrefs, so we substitute them with arg termrefs
2077-
val resultTypeWithSubst = methodType.resultType.substParams(methodType, argRefs.map(_.tpe))
2078-
Typed(result, TypeTree(resultTypeWithSubst).withSpan(fn.span)).withSpan(fn.span)
2079-
case None =>
2080-
result
2081-
}
2082-
case tpd.Block(stats, expr) =>
2083-
seq(stats, rec(expr, topAscription)).withSpan(fn.span)
2084-
case _ =>
2085-
val maybeAscribed = topAscription match {
2086-
case Some(tpt) => Typed(fn, tpt).withSpan(fn.span)
2087-
case None => fn
2088-
}
2089-
maybeAscribed.select(nme.apply).appliedToArgs(argRefs).withSpan(fn.span)
2054+
val reducedBody = lambdaExtractor(fn) match {
2055+
case Some(body) => body(argRefs)
2056+
case None => fn.select(nme.apply).appliedToArgs(argRefs)
2057+
}
2058+
seq(argVals, reducedBody).withSpan(fn.span)
2059+
}
2060+
2061+
def lambdaExtractor(fn: Term)(using ctx: Context): Option[List[Term] => Term] = {
2062+
def rec(fn: Term, transformBody: Term => Term): Option[List[Term] => Term] = {
2063+
fn match {
2064+
case Inlined(call, bindings, expansion) =>
2065+
// this case must go before closureDef to avoid dropping the inline node
2066+
rec(expansion, cpy.Inlined(fn)(call, bindings, _))
2067+
case Typed(expr, tpt) =>
2068+
val resTpe = tpt.tpe.dropDependentRefinement.argInfos.last
2069+
rec(expr, Typed(_, TypeTree(resTpe).withSpan(tpt.span)))
2070+
case closureDef(ddef) =>
2071+
def replace(body: Term, argRefs: List[Term]): Term = {
2072+
val paramSyms = ddef.vparamss.head.map(param => param.symbol)
2073+
val paramToVals = paramSyms.zip(argRefs).toMap
2074+
new TreeTypeMap(
2075+
oldOwners = ddef.symbol :: Nil,
2076+
newOwners = ctx.owner :: Nil,
2077+
treeMap = tree => paramToVals.get(tree.symbol).map(_.withSpan(tree.span)).getOrElse(tree)
2078+
).transform(body)
2079+
}
2080+
Some(argRefs => replace(transformBody(ddef.rhs), argRefs))
2081+
case Block(stats, expr) =>
2082+
// this case must go after closureDef to avoid matching the closure
2083+
rec(expr, cpy.Block(fn)(stats, _))
2084+
case _ =>
2085+
None
2086+
}
20902087
}
2091-
seq(argVals, rec(fn, None))
2088+
rec(fn, identity)
20922089
}
20932090

20942091
/////////////
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package scala.quoted
2+
package matching
3+
4+
/** Lambda expression extractor */
5+
object Lambda {
6+
7+
/** `case Lambda(body)` matche a lambda and extract the body.
8+
* As the body may (will) contain references to the paramter, `body` is a function that recieves those arguments as `Expr`.
9+
* Once this function is applied the result will be the body of the lambda with all references to the parameters replaced.
10+
* If `body` is of type `(T1, T2, ...) => R` then body will be of type `(Expr[T1], Expr[T2], ...) => Expr[R]`.
11+
*
12+
* ```
13+
* '{ (x: Int) => println(x) } match
14+
* case Lambda(body) =>
15+
* // where `body` is: (x: Expr[Int]) => '{ println($x) }
16+
* body('{3}) // returns '{ println(3) }
17+
* ```
18+
*/
19+
def unapply[F, Args <: Tuple, Res, G](expr: Expr[F])(using qctx: QuoteContext, tf: TupledFunction[F, Args => Res], tg: TupledFunction[G, Tuple.Map[Args, Expr] => Expr[Res]]): Option[/*QuoteContext ?=>*/ G] = {
20+
import qctx.tasty.{_, given _ }
21+
qctx.tasty.internal.lambdaExtractor(expr.unseal).map { fn =>
22+
def f(args: Tuple.Map[Args, Expr]): Expr[Res] =
23+
fn(args.toArray.map(_.asInstanceOf[Expr[Any]].unseal).toList).seal.asInstanceOf[Expr[Res]]
24+
tg.untupled(f)
25+
}
26+
27+
}
28+
29+
}

library/src/scala/tasty/reflect/CompilerInterface.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1543,4 +1543,6 @@ trait CompilerInterface {
15431543
*/
15441544
def betaReduce(f: Term, args: List[Term])(using ctx: Context): Term
15451545

1546+
def lambdaExtractor(term: Term)(using ctx: Context): Option[List[Term] => Term]
1547+
15461548
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
scala.Predef.identity[scala.Int](1)
2+
1
3+
{
4+
scala.Predef.println(1)
5+
1
6+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import scala.quoted._
2+
import scala.quoted.matching._
3+
4+
inline def test(inline f: Int => Int): String = ${ impl('f) }
5+
6+
def impl(using QuoteContext)(f: Expr[Int => Int]): Expr[String] = {
7+
Expr(f match {
8+
case Lambda(body) => body('{1}).show
9+
case _ => f.show
10+
})
11+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
@main def Test = {
3+
println(test(identity))
4+
println(test(x => x))
5+
println(test(x => { println(x); x }))
6+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
1.+(2)
2+
{
3+
scala.Predef.println(1)
4+
2
5+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import scala.quoted._
2+
import scala.quoted.matching._
3+
4+
inline def test(inline f: (Int, Int) => Int): String = ${ impl('f) }
5+
6+
def impl(using QuoteContext)(f: Expr[(Int, Int) => Int]): Expr[String] = {
7+
Expr(f match {
8+
case Lambda(body) => body('{1}, '{2}).show
9+
case _ => f.show
10+
})
11+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
2+
@main def Test = {
3+
println(test((x, y) => x + y))
4+
println(test((x, y) => { println(x); y }))
5+
}

tests/run-staging/i3876-c.check

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,5 @@
66

77
(f: scala.Function1[scala.Int, scala.Int] {
88
def apply(x: scala.Int): scala.Int
9-
}).apply(3)
10-
}
9+
})
10+
}.apply(3)

0 commit comments

Comments
 (0)