Skip to content

Commit dc690a2

Browse files
Merge remote-tracking branch 'upstream/master'
2 parents 8ad1b58 + e1b10db commit dc690a2

File tree

21 files changed

+200
-11
lines changed

21 files changed

+200
-11
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/KernelImpl.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package dotty.tools.dotc
22
package tastyreflect
33

4+
import dotty.tools.dotc.ast.Trees.SeqLiteral
45
import dotty.tools.dotc.ast.{Trees, tpd, untpd}
56
import dotty.tools.dotc.ast.tpd.TreeOps
67
import dotty.tools.dotc.typer.Typer
@@ -1073,6 +1074,9 @@ class KernelImpl(val rootContext: core.Contexts.Context, val rootPosition: util.
10731074
def Type_memberType(self: Type)(member: Symbol)(implicit ctx: Context): Type =
10741075
member.info.asSeenFrom(self, member.owner)
10751076

1077+
def Type_derivesFrom(self: Type)(cls: ClassDefSymbol)(implicit ctx: Context): Boolean =
1078+
self.derivesFrom(cls)
1079+
10761080
type ConstantType = Types.ConstantType
10771081

10781082
def matchConstantType(tpe: TypeOrBounds)(implicit ctx: Context): Option[ConstantType] = tpe match {
@@ -1794,7 +1798,7 @@ class KernelImpl(val rootContext: core.Contexts.Context, val rootPosition: util.
17941798
def Definitions_Array_length: Symbol = defn.Array_length.asTerm
17951799
def Definitions_Array_update: Symbol = defn.Array_update.asTerm
17961800

1797-
def Definitions_RepeatedParamClass: Symbol = defn.RepeatedParamClass
1801+
def Definitions_RepeatedParamClass: ClassDefSymbol = defn.RepeatedParamClass
17981802

17991803
def Definitions_OptionClass: Symbol = defn.OptionClass
18001804
def Definitions_NoneModule: Symbol = defn.NoneClass.companionModule.asTerm

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,10 @@ class Typer extends Namer
580580

581581
if (untpd.isWildcardStarArg(tree)) {
582582
def typedWildcardStarArgExpr = {
583-
val tpdExpr = typedExpr(tree.expr)
583+
val ptArg =
584+
if (ctx.mode.is(Mode.QuotedPattern)) pt.underlyingIfRepeated(isJava = false)
585+
else WildcardType
586+
val tpdExpr = typedExpr(tree.expr, ptArg)
584587
tpdExpr.tpe.widenDealias match {
585588
case defn.ArrayOf(_) =>
586589
val starType = defn.ArrayType.appliedTo(WildcardType)
@@ -1960,12 +1963,17 @@ class Typer extends Namer
19601963
object splitter extends tpd.TreeMap {
19611964
val patBuf = new mutable.ListBuffer[Tree]
19621965
override def transform(tree: Tree)(implicit ctx: Context) = tree match {
1963-
case Typed(Splice(pat), tpt) =>
1966+
case Typed(Splice(pat), tpt) if !tpt.tpe.derivesFrom(defn.RepeatedParamClass) =>
19641967
val exprTpt = AppliedTypeTree(TypeTree(defn.QuotedExprType), tpt :: Nil)
19651968
transform(Splice(Typed(pat, exprTpt)))
19661969
case Splice(pat) =>
19671970
try patternHole(tree)
1968-
finally patBuf += pat
1971+
finally {
1972+
val patType = pat.tpe.widen
1973+
val patType1 = patType.underlyingIfRepeated(isJava = false)
1974+
val pat1 = if (patType eq patType1) pat else pat.withType(patType1)
1975+
patBuf += pat1
1976+
}
19691977
case _ =>
19701978
super.transform(tree)
19711979
}

library/src-bootstrapped/scala/internal/quoted/Matcher.scala

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,17 @@ object Matcher {
6464

6565
(normalize(scrutinee), normalize(pattern)) match {
6666

67+
// Match a scala.internal.Quoted.patternHole typed as a repeated argument and return the scrutinee tree
68+
case (IsTerm(scrutinee @ Typed(s, tpt1)), Typed(TypeApply(patternHole, tpt :: Nil), tpt2))
69+
if patternHole.symbol == kernel.Definitions_InternalQuoted_patternHole &&
70+
s.tpe <:< tpt.tpe &&
71+
tpt2.tpe.derivesFrom(definitions.RepeatedParamClass) =>
72+
Some(Tuple1(scrutinee.seal))
73+
6774
// Match a scala.internal.Quoted.patternHole and return the scrutinee tree
6875
case (IsTerm(scrutinee), TypeApply(patternHole, tpt :: Nil))
69-
if patternHole.symbol == kernel.Definitions_InternalQuoted_patternHole && scrutinee.tpe <:< tpt.tpe =>
76+
if patternHole.symbol == kernel.Definitions_InternalQuoted_patternHole &&
77+
scrutinee.tpe <:< tpt.tpe =>
7078
Some(Tuple1(scrutinee.seal))
7179

7280
//
@@ -85,7 +93,7 @@ object Matcher {
8593
case (Select(qual1, _), Select(qual2, _)) if scrutinee.symbol == pattern.symbol =>
8694
treeMatches(qual1, qual2)
8795

88-
case (IsRef(_), IsRef(_, _)) if scrutinee.symbol == pattern.symbol =>
96+
case (IsRef(_), IsRef(_)) if scrutinee.symbol == pattern.symbol =>
8997
Some(())
9098

9199
case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol =>
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package scala.quoted.matching
2+
3+
import scala.quoted.Expr
4+
5+
import scala.tasty.Reflection // TODO do not depend on reflection directly
6+
7+
/** Matches a literal sequence of expressions */
8+
object Repeated {
9+
10+
def unapply[T](expr: Expr[Seq[T]])(implicit reflect: Reflection): Option[Seq[Expr[T]]] = {
11+
import reflect.{Repeated => RepeatedTree, _} // TODO rename to avoid clash
12+
def repeated(tree: Term): Option[Seq[Expr[T]]] = tree match {
13+
case Typed(RepeatedTree(elems, _), _) => Some(elems.map(x => x.seal.asInstanceOf[Expr[T]]))
14+
case Block(Nil, e) => repeated(e)
15+
case Inlined(_, Nil, e) => repeated(e)
16+
case _ => None
17+
}
18+
repeated(expr.unseal)
19+
}
20+
21+
}

library/src/scala/tasty/reflect/Kernel.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,9 @@ trait Kernel {
862862

863863
def Type_memberType(self: Type)(member: Symbol)(implicit ctx: Context): Type
864864

865+
/** Is this type an instance of a non-bottom subclass of the given class `cls`? */
866+
def Type_derivesFrom(self: Type)(cls: ClassDefSymbol)(implicit ctx: Context): Boolean
867+
865868
/** A singleton type representing a known constant value */
866869
type ConstantType <: Type
867870

@@ -1454,7 +1457,7 @@ trait Kernel {
14541457
def Definitions_Array_length: Symbol
14551458
def Definitions_Array_update: Symbol
14561459

1457-
def Definitions_RepeatedParamClass: Symbol
1460+
def Definitions_RepeatedParamClass: ClassDefSymbol
14581461

14591462
def Definitions_OptionClass: Symbol
14601463
def Definitions_NoneModule: Symbol

library/src/scala/tasty/reflect/StandardDefinitions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ trait StandardDefinitions extends Core {
106106
/** A dummy class symbol that is used to indicate repeated parameters
107107
* compiled by the Scala compiler.
108108
*/
109-
def RepeatedParamClass: Symbol = kernel.Definitions_RepeatedParamClass
109+
def RepeatedParamClass: ClassDefSymbol = kernel.Definitions_RepeatedParamClass
110110

111111
/** The class symbol of class `scala.Option`. */
112112
def OptionClass: Symbol = kernel.Definitions_OptionClass

library/src/scala/tasty/reflect/TypeOrBoundsOps.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ trait TypeOrBoundsOps extends Core {
2222
def typeSymbol(implicit ctx: Context): Symbol = kernel.Type_typeSymbol(self)
2323
def isSingleton(implicit ctx: Context): Boolean = kernel.Type_isSingleton(self)
2424
def memberType(member: Symbol)(implicit ctx: Context): Type = kernel.Type_memberType(self)(member)
25+
26+
/** Is this type an instance of a non-bottom subclass of the given class `cls`? */
27+
def derivesFrom(cls: ClassDefSymbol)(implicit ctx: Context): Boolean =
28+
kernel.Type_derivesFrom(self)(cls)
29+
2530
}
2631

2732
object IsType {

tests/pos/i6253.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import scala.quoted._
2+
import scala.tasty.Reflection
3+
object Macros {
4+
def impl(self: Expr[StringContext]) given Reflection: Expr[String] = self match {
5+
case '{ StringContext() } => '{""}
6+
case '{ StringContext($part1) } => part1
7+
case '{ StringContext($part1, $part2) } => '{ $part1 + $part2 }
8+
case '{ StringContext($parts: _*) } => '{ $parts.mkString }
9+
}
10+
}

tests/run-with-compiler/i6253-b.check

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Hello World
2+
Hello World
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import scala.quoted._
2+
import scala.quoted.matching._
3+
4+
import scala.tasty.Reflection
5+
6+
object Macros {
7+
8+
inline def (self: => StringContext) xyz(args: => String*): String = ${impl('self, 'args)}
9+
10+
private def impl(self: Expr[StringContext], args: Expr[Seq[String]])(implicit reflect: Reflection): Expr[String] = {
11+
self match {
12+
case '{ StringContext($parts: _*) } =>
13+
'{
14+
val p: Seq[String] = $parts
15+
val a: Seq[Any] = $args ++ Seq("")
16+
p.zip(a).map(_ + _.toString).mkString
17+
}
18+
case _ =>
19+
'{ "ERROR" }
20+
}
21+
}
22+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import Macros._
2+
3+
object Test {
4+
5+
def main(args: Array[String]): Unit = {
6+
println(xyz"Hello World")
7+
println(xyz"Hello ${"World"}")
8+
}
9+
10+
}

tests/run-with-compiler/i6253.check

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Hello World
2+
Hello World
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import scala.quoted._
2+
import scala.quoted.matching._
3+
4+
import scala.tasty.Reflection
5+
6+
object Macros {
7+
8+
inline def (self: => StringContext) xyz(args: => String*): String = ${impl('self, 'args)}
9+
10+
private def impl(self: Expr[StringContext], args: Expr[Seq[String]])(implicit reflect: Reflection): Expr[String] = {
11+
self match {
12+
case '{ StringContext($parts: _*) } =>
13+
'{ StringContext($parts: _*).s($args: _*) }
14+
case _ =>
15+
'{ "ERROR" }
16+
}
17+
}
18+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import Macros._
2+
3+
object Test {
4+
5+
def main(args: Array[String]): Unit = {
6+
println(xyz"Hello World")
7+
println(xyz"Hello ${"World"}")
8+
}
9+
10+
}

tests/run-with-compiler/quote-matcher-runtime.check

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ Result: Some(List())
190190

191191
Scrutinee: fs()
192192
Pattern: fs((scala.internal.Quoted.patternHole[scala.Seq[scala.Int]]: scala.<repeated>[scala.Int]))
193-
Result: Some(List(Expr()))
193+
Result: Some(List(Expr((: scala.<repeated>[scala.Int]))))
194194

195195
Scrutinee: fs((1, 2, 3: scala.<repeated>[scala.Int]))
196196
Pattern: fs((1, 2, 3: scala.<repeated>[scala.Int]))
@@ -202,7 +202,7 @@ Result: Some(List(Expr(1), Expr(2)))
202202

203203
Scrutinee: fs((1, 2, 3: scala.<repeated>[scala.Int]))
204204
Pattern: fs((scala.internal.Quoted.patternHole[scala.Seq[scala.Int]]: scala.<repeated>[scala.Int]))
205-
Result: Some(List(Expr(1, 2, 3)))
205+
Result: Some(List(Expr((1, 2, 3: scala.<repeated>[scala.Int]))))
206206

207207
Scrutinee: f2(1, 2)
208208
Pattern: f2(1, 2)
@@ -246,7 +246,7 @@ Result: Some(List(Expr("abc"), Expr("xyz")))
246246

247247
Scrutinee: scala.StringContext.apply(("abc", "xyz": scala.<repeated>[scala.Predef.String]))
248248
Pattern: scala.StringContext.apply((scala.internal.Quoted.patternHole[scala.Seq[scala.Predef.String]]: scala.<repeated>[scala.Predef.String]))
249-
Result: Some(List(Expr("abc", "xyz")))
249+
Result: Some(List(Expr(("abc", "xyz": scala.<repeated>[scala.Predef.String]))))
250250

251251
Scrutinee: {
252252
val a: scala.Int = 45
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
dlroW olleH
2+
olleHWorld
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import scala.quoted._
2+
import scala.quoted.matching._
3+
4+
import scala.tasty.Reflection
5+
6+
object Macros {
7+
8+
inline def (self: => StringContext) xyz(args: => String*): String = ${impl('self, 'args)}
9+
10+
private def impl(self: Expr[StringContext], args: Expr[Seq[String]])(implicit reflect: Reflection): Expr[String] = {
11+
(self, args) match {
12+
case ('{ StringContext(${Repeated(parts)}: _*) }, Repeated(args1)) =>
13+
val strParts = parts.map { case Literal(str) => str.reverse }
14+
val strArgs = args1.map { case Literal(str) => str }
15+
StringContext(strParts: _*).s(strArgs: _*).toExpr
16+
case _ => ???
17+
}
18+
19+
}
20+
21+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import Macros._
2+
3+
object Test {
4+
5+
def main(args: Array[String]): Unit = {
6+
println(xyz"Hello World")
7+
println(xyz"Hello ${"World"}")
8+
}
9+
10+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
dlroW olleH
2+
olleHWorld
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import scala.quoted._
2+
import scala.quoted.matching._
3+
4+
import scala.tasty.Reflection
5+
6+
object Macros {
7+
8+
inline def (self: => StringContext) xyz(args: => String*): String = ${impl('self, 'args)}
9+
10+
private def impl(self: Expr[StringContext], args: Expr[Seq[String]])(implicit reflect: Reflection): Expr[String] = {
11+
self match {
12+
case '{ StringContext(${Repeated(parts)}: _*) } =>
13+
val parts2 = parts.map(x => '{ $x.reverse }).toList.toExprOfList
14+
'{ StringContext($parts2: _*).s($args: _*) }
15+
case _ =>
16+
'{ "ERROR" }
17+
}
18+
19+
}
20+
21+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import Macros._
2+
3+
object Test {
4+
5+
def main(args: Array[String]): Unit = {
6+
println(xyz"Hello World")
7+
println(xyz"Hello ${"World"}")
8+
}
9+
10+
}

0 commit comments

Comments
 (0)