Skip to content

Commit c660d4c

Browse files
committed
Add extractor for function literals
1 parent 5db3b72 commit c660d4c

File tree

3 files changed

+51
-0
lines changed

3 files changed

+51
-0
lines changed

library/src/scala/tasty/reflect/TreeOps.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,25 @@ trait TreeOps extends Core {
574574
def tpeOpt(implicit ctx: Context): Option[Type] = kernel.Closure_tpeOpt(self)
575575
}
576576

577+
/** A lambda `(...) => ...` in the source code is represented as
578+
* a local method and a closure:
579+
*
580+
* {
581+
* def m(...) = ...
582+
* closure(m)
583+
* }
584+
*
585+
*/
586+
object Lambda {
587+
def unapply(tree: Tree)(implicit ctx: Context): Option[(List[ValDef], Term)] = tree match {
588+
case Block((ddef @ DefDef(_, _, params :: Nil, _, Some(body))) :: Nil, Closure(meth, _))
589+
if ddef.symbol == meth.symbol =>
590+
Some(params, body)
591+
592+
case _ => None
593+
}
594+
}
595+
577596
object IsIf {
578597
/** Matches any If and returns it */
579598
def unapply(tree: Tree)(implicit ctx: Context): Option[If] = kernel.matchIf(tree)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import scala.quoted._
2+
import scala.tasty._
3+
4+
object lib {
5+
6+
inline def assert(condition: => Boolean): Unit = ${ assertImpl('condition, '{""}) }
7+
8+
def assertImpl(cond: Expr[Boolean], clue: Expr[Any])(implicit refl: Reflection): Expr[Unit] = {
9+
import refl._
10+
import util._
11+
12+
cond.unseal.underlyingArgument match {
13+
case t @ Apply(Select(lhs, op), Lambda(param :: Nil, Apply(Select(a, "=="), b :: Nil)) :: Nil)
14+
if a.symbol == param.symbol || b.symbol == param.symbol =>
15+
'{ scala.Predef.assert($cond) }
16+
}
17+
}
18+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
object Test {
2+
import lib._
3+
4+
case class IntList(args: Int*) {
5+
def exists(f: Int => Boolean): Boolean = args.exists(f)
6+
}
7+
8+
def main(args: Array[String]): Unit = {
9+
assert(IntList(3, 5).exists(_ == 3))
10+
assert(IntList(3, 5).exists(5 == _))
11+
assert(IntList(3, 5).exists(x => x == 3))
12+
assert(IntList(3, 5).exists(x => 5 == x))
13+
}
14+
}

0 commit comments

Comments
 (0)