diff --git a/docs/docs/reference/metaprogramming/macros.md b/docs/docs/reference/metaprogramming/macros.md index 5d3da4eefcff..e3c9c500495a 100644 --- a/docs/docs/reference/metaprogramming/macros.md +++ b/docs/docs/reference/metaprogramming/macros.md @@ -614,4 +614,81 @@ compilation of the suspended files using the output of the previous (partial) co In case all files are suspended due to cyclic dependencies the compilation will fail with an error. +### Pattern matching on quoted expressions + +It is possible to deconstruct or extract values out of `Expr` using pattern matching. + +#### scala.quoted.matching + +In `scala.quoted.matching` contains object that can help extract values from `Expr`. + +* `scala.quoted.matching.Const`: matches an expression a literal value and returns the value. +* `scala.quoted.matching.ExprSeq`: matches an explicit sequence of expresions and returns them. These sequences are useful to get individual `Expr[T]` out of a varargs expression of type `Expr[Seq[T]]`. +* `scala.quoted.matching.ConstSeq`: matches an explicit sequence of literal values and returns them. + +These could be used in the following way to optimize any call to `sum` that has statically known values. +```scala +inline def sum(args: =>Int*): Int = ${ sumExpr('args) } +private def sumExpr(argsExpr: Expr[Seq[Int]])(given QuoteContext): Expr[Int] = argsExpr.underlyingArgument match { + case ConstSeq(args) => // args is of type Seq[Int] + Expr(args.sum) // precompute result of sum + case ExprSeq(argExprs) => // argExprs is of type Seq[Expr[Int]] + val staticSum: Int = argExprs.map { + case Const(arg) => arg + case _ => 0 + }.sum + val dynamicSum: Seq[Expr[Int]] = argExprs.filter { + case Const(_) => false + case arg => true + } + dynamicSum.foldLeft(Expr(staticSum))((acc, arg) => '{ $acc + $arg }) + case _ => + '{ $argsExpr.sum } +} +``` + +#### Quoted patterns + +Quoted pattens allow to deconstruct complex code that contains a precise structure, types or methods. +Patterns `'{ ... }` can be placed in any location where Scala expects a pattern. + +For example +```scala +optimize { + sum(sum(1, a, 2), 3, b) +} // should be optimized to 6 + a + b +``` + +```scala +def sum(args: =>Int*): Int = args.sum +inline def optimize(arg: Int): Int = ${ optimizeExpr('arg) } +private def optimizeExpr(body: Expr[Int])(given QuoteContext): Expr[Int] = body match { + // Match a call to sum without any arguments + case '{ sum() } => Expr(0) + // Match a call to sum with an argument $n of type Int. n will be the Expr[Int] representing the argument. + case '{ sum($n) } => n + // Match a call to sum and extracts all its args in an `Expr[Seq[Int]]` + case '{ sum(${ExprSeq(args)}: _*) } => sumExpr(args) + case body => body +} +private def sumExpr(args1: Seq[Expr[Int]])(given QuoteContext): Expr[Int] = { + def flatSumArgs(arg: Expr[Int]): Seq[Expr[Int]] = arg match { + case '{ sum(${ExprSeq(subArgs)}: _*) } => subArgs.flatMap(flatSumArgs) + case arg => Seq(arg) + } + val args2 = args1.flatMap(flatSumArgs) + val staticSum: Int = args2.map { + case Const(arg) => arg + case _ => 0 + }.sum + val dynamicSum: Seq[Expr[Int]] = args2.filter { + case Const(_) => false + case arg => true + } + dynamicSum.foldLeft(Expr(staticSum))((acc, arg) => '{ $acc + $arg }) +} +``` + + +### More details [More details](./macros-spec.md) diff --git a/library/src/scala/quoted/Expr.scala b/library/src/scala/quoted/Expr.scala index 8d37819b107f..9cc62e55cd62 100644 --- a/library/src/scala/quoted/Expr.scala +++ b/library/src/scala/quoted/Expr.scala @@ -30,6 +30,26 @@ package quoted { final def matches(that: Expr[Any])(given qctx: QuoteContext): Boolean = !scala.internal.quoted.Expr.unapply[Unit, Unit](this)(given that, false, qctx).isEmpty + /** Returns the undelying argument that was in the call before inlining. + * + * ``` + * inline foo(x: Int): Int = baz(x, x) + * foo(bar()) + * ``` + * is inlined as + * ``` + * val x = bar() + * baz(x, x) + * ``` + * in this case the undelying argument of `x` will be `bar()`. + * + * Warning: Using the undelying argument directly in the expansion of a macro may change the parameter + * semantics from by-value to by-name. + */ + def underlyingArgument(given qctx: QuoteContext): Expr[T] = { + import qctx.tasty.{given, _} + this.unseal.underlyingArgument.seal.asInstanceOf[Expr[T]] + } } object Expr { diff --git a/library/src/scala/quoted/matching/ConstSeq.scala b/library/src/scala/quoted/matching/ConstSeq.scala index 742773ec9007..1d4a7769790f 100644 --- a/library/src/scala/quoted/matching/ConstSeq.scala +++ b/library/src/scala/quoted/matching/ConstSeq.scala @@ -4,7 +4,18 @@ package matching /** Literal sequence of literal constant value expressions */ object ConstSeq { - /** Matches literal sequence of literal constant value expressions */ + /** Matches literal sequence of literal constant value expressions and return a sequence of values. + * + * Usage: + * ```scala + * inline def sum(args: Int*): Int = ${ sumExpr('args) } + * def sumExpr(argsExpr: Expr[Seq[Int]])(given QuoteContext): Expr[Int] = argsExpr match + * case ConstSeq(args) => + * // args: Seq[Int] + * ... + * } + * ``` + */ def unapply[T](expr: Expr[Seq[T]])(given qctx: QuoteContext): Option[Seq[T]] = expr match { case ExprSeq(elems) => elems.foldRight(Option(List.empty[T])) { (elem, acc) => diff --git a/library/src/scala/quoted/matching/ExprSeq.scala b/library/src/scala/quoted/matching/ExprSeq.scala index fc29a771a1a5..717b5b9471de 100644 --- a/library/src/scala/quoted/matching/ExprSeq.scala +++ b/library/src/scala/quoted/matching/ExprSeq.scala @@ -4,7 +4,18 @@ package matching /** Literal sequence of expressions */ object ExprSeq { - /** Matches a literal sequence of expressions */ + /** Matches a literal sequence of expressions and return a sequence of expressions. + * + * Usage: + * ```scala + * inline def sum(args: Int*): Int = ${ sumExpr('args) } + * def sumExpr(argsExpr: Expr[Seq[Int]])(given QuoteContext): Expr[Int] = argsExpr match + * case ExprSeq(argExprs) => + * // argExprs: Seq[Expr[Int]] + * ... + * } + * ``` + */ def unapply[T](expr: Expr[Seq[T]])(given qctx: QuoteContext): Option[Seq[Expr[T]]] = { import qctx.tasty.{_, given} def rec(tree: Term): Option[Seq[Expr[T]]] = tree match { diff --git a/tests/run-macros/quoted-matching-docs-2.check b/tests/run-macros/quoted-matching-docs-2.check new file mode 100644 index 000000000000..a00f5528b2ab --- /dev/null +++ b/tests/run-macros/quoted-matching-docs-2.check @@ -0,0 +1,6 @@ +6 +6 +12.+(Test.a) +17 +4.+(Macro_1$package.sum((Test.seq: scala.[scala.Int]))) +13 diff --git a/tests/run-macros/quoted-matching-docs-2/Macro_1.scala b/tests/run-macros/quoted-matching-docs-2/Macro_1.scala new file mode 100644 index 000000000000..811936b9cb0e --- /dev/null +++ b/tests/run-macros/quoted-matching-docs-2/Macro_1.scala @@ -0,0 +1,37 @@ +import scala.quoted._ +import scala.quoted.matching._ + +def sum(args: =>Int*): Int = args.sum + +inline def showOptimize(arg: Int): String = ${ showOptimizeExpr('arg) } +inline def optimize(arg: Int): Int = ${ optimizeExpr('arg) } + +private def showOptimizeExpr(body: Expr[Int])(given QuoteContext): Expr[String] = + Expr(optimizeExpr(body).show) + +private def optimizeExpr(body: Expr[Int])(given QuoteContext): Expr[Int] = body match { + // Match a call to sum without any arguments + case '{ sum() } => Expr(0) + // Match a call to sum with an argument $n of type Int. n will be the Expr[Int] representing the argument. + case '{ sum($n) } => n + // Match a call to sum and extracts all its args in an `Expr[Seq[Int]]` + case '{ sum(${ExprSeq(args)}: _*) } => sumExpr(args) + case body => body +} + +private def sumExpr(args1: Seq[Expr[Int]])(given QuoteContext): Expr[Int] = { + def flatSumArgs(arg: Expr[Int]): Seq[Expr[Int]] = arg match { + case '{ sum(${ExprSeq(subArgs)}: _*) } => subArgs.flatMap(flatSumArgs) + case arg => Seq(arg) + } + val args2 = args1.flatMap(flatSumArgs) + val staticSum: Int = args2.map { + case Const(arg) => arg + case _ => 0 + }.sum + val dynamicSum: Seq[Expr[Int]] = args2.filter { + case Const(_) => false + case arg => true + } + dynamicSum.foldLeft(Expr(staticSum))((acc, arg) => '{ $acc + $arg }) +} diff --git a/tests/run-macros/quoted-matching-docs-2/Test_2.scala b/tests/run-macros/quoted-matching-docs-2/Test_2.scala new file mode 100644 index 000000000000..784cdd12ae86 --- /dev/null +++ b/tests/run-macros/quoted-matching-docs-2/Test_2.scala @@ -0,0 +1,10 @@ +object Test extends App { + println(showOptimize(sum(1, 2, 3))) + println(optimize(sum(1, 2, 3))) + val a: Int = 5 + println(showOptimize(sum(1, a, sum(1, 2, 3), 5))) + println(optimize(sum(1, a, sum(1, 2, 3), 5))) + val seq: Seq[Int] = Seq(1, 3, 5) + println(showOptimize(sum(1, sum(seq: _*), 3))) + println(optimize(sum(1, sum(seq: _*), 3))) +} diff --git a/tests/run-macros/quoted-matching-docs.check b/tests/run-macros/quoted-matching-docs.check new file mode 100644 index 000000000000..9ff2df78d8c0 --- /dev/null +++ b/tests/run-macros/quoted-matching-docs.check @@ -0,0 +1,6 @@ +6 +6 +10.+(Test.a) +15 +args.sum[scala.Int](scala.math.Numeric.IntIsIntegral) +9 diff --git a/tests/run-macros/quoted-matching-docs/Macro_1.scala b/tests/run-macros/quoted-matching-docs/Macro_1.scala new file mode 100644 index 000000000000..480e8240d934 --- /dev/null +++ b/tests/run-macros/quoted-matching-docs/Macro_1.scala @@ -0,0 +1,29 @@ +import scala.quoted._ +import scala.quoted.matching._ + +inline def sum(args: Int*): Int = ${ sumExpr('args) } + +inline def sumShow(args: Int*): String = ${ sumExprShow('args) } + +private def sumExprShow(argsExpr: Expr[Seq[Int]])(given QuoteContext): Expr[String] = + Expr(sumExpr(argsExpr).show) + +private def sumExpr(argsExpr: Expr[Seq[Int]])(given qctx: QuoteContext): Expr[Int] = { + import qctx.tasty.{given, _} + argsExpr.underlyingArgument match { + case ConstSeq(args) => // args is of type Seq[Int] + Expr(args.sum) // precompute result of sum + case ExprSeq(argExprs) => // argExprs is of type Seq[Expr[Int]] + val staticSum: Int = argExprs.map { + case Const(arg) => arg + case _ => 0 + }.sum + val dynamicSum: Seq[Expr[Int]] = argExprs.filter { + case Const(_) => false + case arg => true + } + dynamicSum.foldLeft(Expr(staticSum))((acc, arg) => '{ $acc + $arg }) + case _ => + '{ $argsExpr.sum } + } +} \ No newline at end of file diff --git a/tests/run-macros/quoted-matching-docs/Test_2.scala b/tests/run-macros/quoted-matching-docs/Test_2.scala new file mode 100644 index 000000000000..d196181591b7 --- /dev/null +++ b/tests/run-macros/quoted-matching-docs/Test_2.scala @@ -0,0 +1,10 @@ +object Test extends App { + println(sumShow(1, 2, 3)) + println(sum(1, 2, 3)) + val a: Int = 5 + println(sumShow(1, a, 4, 5)) + println(sum(1, a, 4, 5)) + val seq: Seq[Int] = Seq(1, 3, 5) + println(sumShow(seq: _*)) + println(sum(seq: _*)) +}