Skip to content

Add quote pattern matching docs and simpler underlyingArgument #7697

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions docs/docs/reference/metaprogramming/macros.md
Original file line number Diff line number Diff line change
Expand Up @@ -614,4 +614,81 @@ compilation of the suspended files using the output of the previous (partial) co
In case all files are suspended due to cyclic dependencies the compilation will fail with an error.


### Pattern matching on quoted expressions

It is possible to deconstruct or extract values out of `Expr` using pattern matching.

#### scala.quoted.matching

In `scala.quoted.matching` contains object that can help extract values from `Expr`.

* `scala.quoted.matching.Const`: matches an expression a literal value and returns the value.
* `scala.quoted.matching.ExprSeq`: matches an explicit sequence of expresions and returns them. These sequences are useful to get individual `Expr[T]` out of a varargs expression of type `Expr[Seq[T]]`.
* `scala.quoted.matching.ConstSeq`: matches an explicit sequence of literal values and returns them.

These could be used in the following way to optimize any call to `sum` that has statically known values.
```scala
inline def sum(args: =>Int*): Int = ${ sumExpr('args) }
private def sumExpr(argsExpr: Expr[Seq[Int]])(given QuoteContext): Expr[Int] = argsExpr.underlyingArgument match {
case ConstSeq(args) => // args is of type Seq[Int]
Expr(args.sum) // precompute result of sum
case ExprSeq(argExprs) => // argExprs is of type Seq[Expr[Int]]
val staticSum: Int = argExprs.map {
case Const(arg) => arg
case _ => 0
}.sum
val dynamicSum: Seq[Expr[Int]] = argExprs.filter {
case Const(_) => false
case arg => true
}
dynamicSum.foldLeft(Expr(staticSum))((acc, arg) => '{ $acc + $arg })
case _ =>
'{ $argsExpr.sum }
}
```

#### Quoted patterns

Quoted pattens allow to deconstruct complex code that contains a precise structure, types or methods.
Patterns `'{ ... }` can be placed in any location where Scala expects a pattern.

For example
```scala
optimize {
sum(sum(1, a, 2), 3, b)
} // should be optimized to 6 + a + b
```

```scala
def sum(args: =>Int*): Int = args.sum
inline def optimize(arg: Int): Int = ${ optimizeExpr('arg) }
private def optimizeExpr(body: Expr[Int])(given QuoteContext): Expr[Int] = body match {
// Match a call to sum without any arguments
case '{ sum() } => Expr(0)
// Match a call to sum with an argument $n of type Int. n will be the Expr[Int] representing the argument.
case '{ sum($n) } => n
// Match a call to sum and extracts all its args in an `Expr[Seq[Int]]`
case '{ sum(${ExprSeq(args)}: _*) } => sumExpr(args)
case body => body
}
private def sumExpr(args1: Seq[Expr[Int]])(given QuoteContext): Expr[Int] = {
def flatSumArgs(arg: Expr[Int]): Seq[Expr[Int]] = arg match {
case '{ sum(${ExprSeq(subArgs)}: _*) } => subArgs.flatMap(flatSumArgs)
case arg => Seq(arg)
}
val args2 = args1.flatMap(flatSumArgs)
val staticSum: Int = args2.map {
case Const(arg) => arg
case _ => 0
}.sum
val dynamicSum: Seq[Expr[Int]] = args2.filter {
case Const(_) => false
case arg => true
}
dynamicSum.foldLeft(Expr(staticSum))((acc, arg) => '{ $acc + $arg })
}
```


### More details
[More details](./macros-spec.md)
20 changes: 20 additions & 0 deletions library/src/scala/quoted/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,26 @@ package quoted {
final def matches(that: Expr[Any])(given qctx: QuoteContext): Boolean =
!scala.internal.quoted.Expr.unapply[Unit, Unit](this)(given that, false, qctx).isEmpty

/** Returns the undelying argument that was in the call before inlining.
*
* ```
* inline foo(x: Int): Int = baz(x, x)
* foo(bar())
* ```
* is inlined as
* ```
* val x = bar()
* baz(x, x)
* ```
* in this case the undelying argument of `x` will be `bar()`.
*
* Warning: Using the undelying argument directly in the expansion of a macro may change the parameter
* semantics from by-value to by-name.
*/
def underlyingArgument(given qctx: QuoteContext): Expr[T] = {
import qctx.tasty.{given, _}
this.unseal.underlyingArgument.seal.asInstanceOf[Expr[T]]
}
}

object Expr {
Expand Down
13 changes: 12 additions & 1 deletion library/src/scala/quoted/matching/ConstSeq.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,18 @@ package matching
/** Literal sequence of literal constant value expressions */
object ConstSeq {

/** Matches literal sequence of literal constant value expressions */
/** Matches literal sequence of literal constant value expressions and return a sequence of values.
*
* Usage:
* ```scala
* inline def sum(args: Int*): Int = ${ sumExpr('args) }
* def sumExpr(argsExpr: Expr[Seq[Int]])(given QuoteContext): Expr[Int] = argsExpr match
* case ConstSeq(args) =>
* // args: Seq[Int]
* ...
* }
* ```
*/
def unapply[T](expr: Expr[Seq[T]])(given qctx: QuoteContext): Option[Seq[T]] = expr match {
case ExprSeq(elems) =>
elems.foldRight(Option(List.empty[T])) { (elem, acc) =>
Expand Down
13 changes: 12 additions & 1 deletion library/src/scala/quoted/matching/ExprSeq.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,18 @@ package matching
/** Literal sequence of expressions */
object ExprSeq {

/** Matches a literal sequence of expressions */
/** Matches a literal sequence of expressions and return a sequence of expressions.
*
* Usage:
* ```scala
* inline def sum(args: Int*): Int = ${ sumExpr('args) }
* def sumExpr(argsExpr: Expr[Seq[Int]])(given QuoteContext): Expr[Int] = argsExpr match
* case ExprSeq(argExprs) =>
* // argExprs: Seq[Expr[Int]]
* ...
* }
* ```
*/
def unapply[T](expr: Expr[Seq[T]])(given qctx: QuoteContext): Option[Seq[Expr[T]]] = {
import qctx.tasty.{_, given}
def rec(tree: Term): Option[Seq[Expr[T]]] = tree match {
Expand Down
6 changes: 6 additions & 0 deletions tests/run-macros/quoted-matching-docs-2.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
6
6
12.+(Test.a)
17
4.+(Macro_1$package.sum((Test.seq: scala.<repeated>[scala.Int])))
13
37 changes: 37 additions & 0 deletions tests/run-macros/quoted-matching-docs-2/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import scala.quoted._
import scala.quoted.matching._

def sum(args: =>Int*): Int = args.sum

inline def showOptimize(arg: Int): String = ${ showOptimizeExpr('arg) }
inline def optimize(arg: Int): Int = ${ optimizeExpr('arg) }

private def showOptimizeExpr(body: Expr[Int])(given QuoteContext): Expr[String] =
Expr(optimizeExpr(body).show)

private def optimizeExpr(body: Expr[Int])(given QuoteContext): Expr[Int] = body match {
// Match a call to sum without any arguments
case '{ sum() } => Expr(0)
// Match a call to sum with an argument $n of type Int. n will be the Expr[Int] representing the argument.
case '{ sum($n) } => n
// Match a call to sum and extracts all its args in an `Expr[Seq[Int]]`
case '{ sum(${ExprSeq(args)}: _*) } => sumExpr(args)
case body => body
}

private def sumExpr(args1: Seq[Expr[Int]])(given QuoteContext): Expr[Int] = {
def flatSumArgs(arg: Expr[Int]): Seq[Expr[Int]] = arg match {
case '{ sum(${ExprSeq(subArgs)}: _*) } => subArgs.flatMap(flatSumArgs)
case arg => Seq(arg)
}
val args2 = args1.flatMap(flatSumArgs)
val staticSum: Int = args2.map {
case Const(arg) => arg
case _ => 0
}.sum
val dynamicSum: Seq[Expr[Int]] = args2.filter {
case Const(_) => false
case arg => true
}
dynamicSum.foldLeft(Expr(staticSum))((acc, arg) => '{ $acc + $arg })
}
10 changes: 10 additions & 0 deletions tests/run-macros/quoted-matching-docs-2/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
object Test extends App {
println(showOptimize(sum(1, 2, 3)))
println(optimize(sum(1, 2, 3)))
val a: Int = 5
println(showOptimize(sum(1, a, sum(1, 2, 3), 5)))
println(optimize(sum(1, a, sum(1, 2, 3), 5)))
val seq: Seq[Int] = Seq(1, 3, 5)
println(showOptimize(sum(1, sum(seq: _*), 3)))
println(optimize(sum(1, sum(seq: _*), 3)))
}
6 changes: 6 additions & 0 deletions tests/run-macros/quoted-matching-docs.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
6
6
10.+(Test.a)
15
args.sum[scala.Int](scala.math.Numeric.IntIsIntegral)
9
29 changes: 29 additions & 0 deletions tests/run-macros/quoted-matching-docs/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import scala.quoted._
import scala.quoted.matching._

inline def sum(args: Int*): Int = ${ sumExpr('args) }

inline def sumShow(args: Int*): String = ${ sumExprShow('args) }

private def sumExprShow(argsExpr: Expr[Seq[Int]])(given QuoteContext): Expr[String] =
Expr(sumExpr(argsExpr).show)

private def sumExpr(argsExpr: Expr[Seq[Int]])(given qctx: QuoteContext): Expr[Int] = {
import qctx.tasty.{given, _}
argsExpr.underlyingArgument match {
case ConstSeq(args) => // args is of type Seq[Int]
Expr(args.sum) // precompute result of sum
case ExprSeq(argExprs) => // argExprs is of type Seq[Expr[Int]]
val staticSum: Int = argExprs.map {
case Const(arg) => arg
case _ => 0
}.sum
val dynamicSum: Seq[Expr[Int]] = argExprs.filter {
case Const(_) => false
case arg => true
}
dynamicSum.foldLeft(Expr(staticSum))((acc, arg) => '{ $acc + $arg })
case _ =>
'{ $argsExpr.sum }
}
}
10 changes: 10 additions & 0 deletions tests/run-macros/quoted-matching-docs/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
object Test extends App {
println(sumShow(1, 2, 3))
println(sum(1, 2, 3))
val a: Int = 5
println(sumShow(1, a, 4, 5))
println(sum(1, a, 4, 5))
val seq: Seq[Int] = Seq(1, 3, 5)
println(sumShow(seq: _*))
println(sum(seq: _*))
}