Skip to content

Commit 6f9cd74

Browse files
committed
Fix #6253: Handle repeated args in quoted patterns
* Handle repeated args in quoted patterns in typer * Handle repeated args in quoted patterns in Matcher * Add Repeated extractor to get Seq[Expr[T]] from a Expr[Seq[T]]
1 parent 449008e commit 6f9cd74

File tree

10 files changed

+105
-4
lines changed

10 files changed

+105
-4
lines changed

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,10 @@ class Typer extends Namer
580580

581581
if (untpd.isWildcardStarArg(tree)) {
582582
def typedWildcardStarArgExpr = {
583-
val tpdExpr = typedExpr(tree.expr)
583+
val ptArg =
584+
if (ctx.mode.is(Mode.QuotedPattern)) pt.subst(defn.RepeatedParamClass :: Nil, defn.SeqType :: Nil)
585+
else WildcardType
586+
val tpdExpr = typedExpr(tree.expr, ptArg)
584587
tpdExpr.tpe.widenDealias match {
585588
case defn.ArrayOf(_) =>
586589
val starType = defn.ArrayType.appliedTo(WildcardType)
@@ -1960,12 +1963,15 @@ class Typer extends Namer
19601963
object splitter extends tpd.TreeMap {
19611964
val patBuf = new mutable.ListBuffer[Tree]
19621965
override def transform(tree: Tree)(implicit ctx: Context) = tree match {
1963-
case Typed(Splice(pat), tpt) =>
1966+
case Typed(Splice(pat), tpt) if !tpt.tpe.derivesFrom(defn.RepeatedParamClass) =>
19641967
val exprTpt = AppliedTypeTree(TypeTree(defn.QuotedExprType), tpt :: Nil)
19651968
transform(Splice(Typed(pat, exprTpt)))
19661969
case Splice(pat) =>
19671970
try patternHole(tree)
1968-
finally patBuf += pat
1971+
finally {
1972+
val pat1 = pat.subst(defn.RepeatedParamClass :: Nil, defn.SeqClass :: Nil)
1973+
patBuf += pat1
1974+
}
19691975
case _ =>
19701976
super.transform(tree)
19711977
}

library/src-bootstrapped/scala/internal/quoted/Matcher.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ object Matcher {
6767
// Match a scala.internal.Quoted.patternHole and return the scrutinee tree
6868
case (IsTerm(scrutinee), TypeApply(patternHole, tpt :: Nil))
6969
if patternHole.symbol == kernel.Definitions_InternalQuoted_patternHole && scrutinee.tpe <:< tpt.tpe =>
70+
// scrutinee match {
71+
// case Repeated(args, _) =>
72+
// }
7073
Some(Tuple1(scrutinee.seal))
7174

7275
//
@@ -85,7 +88,7 @@ object Matcher {
8588
case (Select(qual1, _), Select(qual2, _)) if scrutinee.symbol == pattern.symbol =>
8689
treeMatches(qual1, qual2)
8790

88-
case (IsRef(_), IsRef(_, _)) if scrutinee.symbol == pattern.symbol =>
91+
case (IsRef(_), IsRef(_)) if scrutinee.symbol == pattern.symbol =>
8992
Some(())
9093

9194
case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol =>
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package scala.quoted.matching
2+
3+
import scala.quoted.Expr
4+
5+
import scala.tasty.Reflection // TODO do not depend on reflection directly
6+
7+
/** Matches a literal sequence of expressions */
8+
object Repeated {
9+
10+
def unapply[T](expr: Expr[Seq[T]])(implicit reflect: Reflection): Option[Seq[Expr[T]]] = {
11+
import reflect.{Repeated => RepeatedTree, _} // TODO rename to avoid clash
12+
def repeated(tree: Term): Option[Seq[Expr[T]]] = tree match {
13+
case RepeatedTree(elems, _) => Some(elems.map(x => x.seal.asInstanceOf[Expr[T]]))
14+
case Block(Nil, e) => repeated(e)
15+
case Inlined(_, Nil, e) => repeated(e)
16+
case _ => None
17+
}
18+
repeated(expr.unseal)
19+
}
20+
21+
}

tests/pos/i6253.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import scala.quoted._
2+
import scala.tasty.Reflection
3+
object Macros {
4+
def impl(self: Expr[StringContext]) given Reflection: Expr[String] = self match {
5+
case '{ StringContext() } => '{""}
6+
case '{ StringContext($part1) } => part1
7+
case '{ StringContext($part1, $part2) } => '{ $part1 + $part2 }
8+
case '{ StringContext($parts: _*) } => '{ $parts.mkString }
9+
}
10+
}

tests/run-with-compiler/i6253.check

Whitespace-only changes.
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import scala.quoted._
2+
import scala.quoted.matching._
3+
4+
import scala.tasty.Reflection
5+
6+
object Macros {
7+
8+
inline def (self: => StringContext) xyz(args: => String*): String = ${impl('self, 'args)}
9+
10+
private def impl(self: Expr[StringContext], args: Expr[Seq[String]])(implicit reflect: Reflection): Expr[String] = {
11+
self match {
12+
case '{ StringContext($parts: _*) } =>
13+
'{ StringContext($parts: _*).s($args: _*) }
14+
case _ =>
15+
'{ "ERROR" }
16+
}
17+
}
18+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import Macros._
2+
3+
object Test {
4+
5+
def main(args: Array[String]): Unit = {
6+
println(xyz"Hello World")
7+
println(xyz"Hello ${"World"}")
8+
}
9+
10+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
dlroW olleH
2+
olleHWorld
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import scala.quoted._
2+
import scala.quoted.matching._
3+
4+
import scala.tasty.Reflection
5+
6+
object Macros {
7+
8+
inline def (self: => StringContext) xyz(args: => String*): String = ${impl('self, 'args)}
9+
10+
private def impl(self: Expr[StringContext], args: Expr[Seq[String]])(implicit reflect: Reflection): Expr[String] = {
11+
self match {
12+
case '{ StringContext(${Repeated(parts)}: _*) } =>
13+
val parts2 = parts.map(x => '{ $x.reverse }).toList.toExprOfList
14+
'{ StringContext($parts2: _*).s($args: _*) }
15+
case _ =>
16+
'{ "ERROR" }
17+
}
18+
19+
}
20+
21+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import Macros._
2+
3+
object Test {
4+
5+
def main(args: Array[String]): Unit = {
6+
println(xyz"Hello World")
7+
println(xyz"Hello ${"World"}")
8+
}
9+
10+
}

0 commit comments

Comments
 (0)