Skip to content

Fix #6253: Handle repeated args in quoted patterns #6254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion compiler/src/dotty/tools/dotc/tastyreflect/KernelImpl.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dotty.tools.dotc
package tastyreflect

import dotty.tools.dotc.ast.Trees.SeqLiteral
import dotty.tools.dotc.ast.{Trees, tpd, untpd}
import dotty.tools.dotc.ast.tpd.TreeOps
import dotty.tools.dotc.typer.Typer
Expand Down Expand Up @@ -1055,6 +1056,9 @@ class KernelImpl(val rootContext: core.Contexts.Context, val rootPosition: util.
def Type_memberType(self: Type)(member: Symbol)(implicit ctx: Context): Type =
member.info.asSeenFrom(self, member.owner)

def Type_derivesFrom(self: Type)(cls: ClassDefSymbol)(implicit ctx: Context): Boolean =
self.derivesFrom(cls)

type ConstantType = Types.ConstantType

def matchConstantType(tpe: TypeOrBounds)(implicit ctx: Context): Option[ConstantType] = tpe match {
Expand Down Expand Up @@ -1774,7 +1778,7 @@ class KernelImpl(val rootContext: core.Contexts.Context, val rootPosition: util.
def Definitions_Array_length: Symbol = defn.Array_length.asTerm
def Definitions_Array_update: Symbol = defn.Array_update.asTerm

def Definitions_RepeatedParamClass: Symbol = defn.RepeatedParamClass
def Definitions_RepeatedParamClass: ClassDefSymbol = defn.RepeatedParamClass

def Definitions_OptionClass: Symbol = defn.OptionClass
def Definitions_NoneModule: Symbol = defn.NoneClass.companionModule.asTerm
Expand Down
14 changes: 11 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,10 @@ class Typer extends Namer

if (untpd.isWildcardStarArg(tree)) {
def typedWildcardStarArgExpr = {
val tpdExpr = typedExpr(tree.expr)
val ptArg =
if (ctx.mode.is(Mode.QuotedPattern)) pt.underlyingIfRepeated(isJava = false)
else WildcardType
val tpdExpr = typedExpr(tree.expr, ptArg)
tpdExpr.tpe.widenDealias match {
case defn.ArrayOf(_) =>
val starType = defn.ArrayType.appliedTo(WildcardType)
Expand Down Expand Up @@ -1960,12 +1963,17 @@ class Typer extends Namer
object splitter extends tpd.TreeMap {
val patBuf = new mutable.ListBuffer[Tree]
override def transform(tree: Tree)(implicit ctx: Context) = tree match {
case Typed(Splice(pat), tpt) =>
case Typed(Splice(pat), tpt) if !tpt.tpe.derivesFrom(defn.RepeatedParamClass) =>
val exprTpt = AppliedTypeTree(TypeTree(defn.QuotedExprType), tpt :: Nil)
transform(Splice(Typed(pat, exprTpt)))
case Splice(pat) =>
try patternHole(tree)
finally patBuf += pat
finally {
val patType = pat.tpe.widen
val patType1 = patType.underlyingIfRepeated(isJava = false)
val pat1 = if (patType eq patType1) pat else pat.withType(patType1)
patBuf += pat1
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The change here and above looks dubious to me. @odersky could you have a look?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure a global substitution is the right thing here. I fear it would also affect nested varargs methods that would then become Seq methods. I'd do instead:

val patType = pat.tpe.widen
val patType1 = patType.underlyingIfRepeated(isJava = false)
val pat1 = if (patType eq patType1) pat else pat.withType(patType1)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I use this code now and also updated another place that had the subst to use underlyingIfRepeated.

case _ =>
super.transform(tree)
}
Expand Down
12 changes: 10 additions & 2 deletions library/src-bootstrapped/scala/internal/quoted/Matcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,17 @@ object Matcher {

(normalize(scrutinee), normalize(pattern)) match {

// Match a scala.internal.Quoted.patternHole typed as a repeated argument and return the scrutinee tree
case (IsTerm(scrutinee @ Typed(s, tpt1)), Typed(TypeApply(patternHole, tpt :: Nil), tpt2))
if patternHole.symbol == kernel.Definitions_InternalQuoted_patternHole &&
s.tpe <:< tpt.tpe &&
tpt2.tpe.derivesFrom(definitions.RepeatedParamClass) =>
Some(Tuple1(scrutinee.seal))

// Match a scala.internal.Quoted.patternHole and return the scrutinee tree
case (IsTerm(scrutinee), TypeApply(patternHole, tpt :: Nil))
if patternHole.symbol == kernel.Definitions_InternalQuoted_patternHole && scrutinee.tpe <:< tpt.tpe =>
if patternHole.symbol == kernel.Definitions_InternalQuoted_patternHole &&
scrutinee.tpe <:< tpt.tpe =>
Some(Tuple1(scrutinee.seal))

//
Expand All @@ -85,7 +93,7 @@ object Matcher {
case (Select(qual1, _), Select(qual2, _)) if scrutinee.symbol == pattern.symbol =>
treeMatches(qual1, qual2)

case (IsRef(_), IsRef(_, _)) if scrutinee.symbol == pattern.symbol =>
case (IsRef(_), IsRef(_)) if scrutinee.symbol == pattern.symbol =>
Some(())

case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol =>
Expand Down
21 changes: 21 additions & 0 deletions library/src/scala/quoted/matching/Repeated.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package scala.quoted.matching

import scala.quoted.Expr

import scala.tasty.Reflection // TODO do not depend on reflection directly

/** Matches a literal sequence of expressions */
object Repeated {

def unapply[T](expr: Expr[Seq[T]])(implicit reflect: Reflection): Option[Seq[Expr[T]]] = {
import reflect.{Repeated => RepeatedTree, _} // TODO rename to avoid clash
def repeated(tree: Term): Option[Seq[Expr[T]]] = tree match {
case Typed(RepeatedTree(elems, _), _) => Some(elems.map(x => x.seal.asInstanceOf[Expr[T]]))
case Block(Nil, e) => repeated(e)
case Inlined(_, Nil, e) => repeated(e)
case _ => None
}
repeated(expr.unseal)
}

}
5 changes: 4 additions & 1 deletion library/src/scala/tasty/reflect/Kernel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,9 @@ trait Kernel {

def Type_memberType(self: Type)(member: Symbol)(implicit ctx: Context): Type

/** Is this type an instance of a non-bottom subclass of the given class `cls`? */
def Type_derivesFrom(self: Type)(cls: ClassDefSymbol)(implicit ctx: Context): Boolean

/** A singleton type representing a known constant value */
type ConstantType <: Type

Expand Down Expand Up @@ -1434,7 +1437,7 @@ trait Kernel {
def Definitions_Array_length: Symbol
def Definitions_Array_update: Symbol

def Definitions_RepeatedParamClass: Symbol
def Definitions_RepeatedParamClass: ClassDefSymbol

def Definitions_OptionClass: Symbol
def Definitions_NoneModule: Symbol
Expand Down
2 changes: 1 addition & 1 deletion library/src/scala/tasty/reflect/StandardDefinitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ trait StandardDefinitions extends Core {
/** A dummy class symbol that is used to indicate repeated parameters
* compiled by the Scala compiler.
*/
def RepeatedParamClass: Symbol = kernel.Definitions_RepeatedParamClass
def RepeatedParamClass: ClassDefSymbol = kernel.Definitions_RepeatedParamClass

/** The class symbol of class `scala.Option`. */
def OptionClass: Symbol = kernel.Definitions_OptionClass
Expand Down
5 changes: 5 additions & 0 deletions library/src/scala/tasty/reflect/TypeOrBoundsOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ trait TypeOrBoundsOps extends Core {
def typeSymbol(implicit ctx: Context): Symbol = kernel.Type_typeSymbol(self)
def isSingleton(implicit ctx: Context): Boolean = kernel.Type_isSingleton(self)
def memberType(member: Symbol)(implicit ctx: Context): Type = kernel.Type_memberType(self)(member)

/** Is this type an instance of a non-bottom subclass of the given class `cls`? */
def derivesFrom(cls: ClassDefSymbol)(implicit ctx: Context): Boolean =
kernel.Type_derivesFrom(self)(cls)

}

object IsType {
Expand Down
10 changes: 10 additions & 0 deletions tests/pos/i6253.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import scala.quoted._
import scala.tasty.Reflection
object Macros {
def impl(self: Expr[StringContext]) given Reflection: Expr[String] = self match {
case '{ StringContext() } => '{""}
case '{ StringContext($part1) } => part1
case '{ StringContext($part1, $part2) } => '{ $part1 + $part2 }
case '{ StringContext($parts: _*) } => '{ $parts.mkString }
}
}
2 changes: 2 additions & 0 deletions tests/run-with-compiler/i6253-b.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Hello World
Hello World
22 changes: 22 additions & 0 deletions tests/run-with-compiler/i6253-b/quoted_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import scala.quoted._
import scala.quoted.matching._

import scala.tasty.Reflection

object Macros {

inline def (self: => StringContext) xyz(args: => String*): String = ${impl('self, 'args)}

private def impl(self: Expr[StringContext], args: Expr[Seq[String]])(implicit reflect: Reflection): Expr[String] = {
self match {
case '{ StringContext($parts: _*) } =>
'{
val p: Seq[String] = $parts
val a: Seq[Any] = $args ++ Seq("")
p.zip(a).map(_ + _.toString).mkString
}
case _ =>
'{ "ERROR" }
}
}
}
10 changes: 10 additions & 0 deletions tests/run-with-compiler/i6253-b/quoted_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import Macros._

object Test {

def main(args: Array[String]): Unit = {
println(xyz"Hello World")
println(xyz"Hello ${"World"}")
}

}
2 changes: 2 additions & 0 deletions tests/run-with-compiler/i6253.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Hello World
Hello World
18 changes: 18 additions & 0 deletions tests/run-with-compiler/i6253/quoted_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import scala.quoted._
import scala.quoted.matching._

import scala.tasty.Reflection

object Macros {

inline def (self: => StringContext) xyz(args: => String*): String = ${impl('self, 'args)}

private def impl(self: Expr[StringContext], args: Expr[Seq[String]])(implicit reflect: Reflection): Expr[String] = {
self match {
case '{ StringContext($parts: _*) } =>
'{ StringContext($parts: _*).s($args: _*) }
case _ =>
'{ "ERROR" }
}
}
}
10 changes: 10 additions & 0 deletions tests/run-with-compiler/i6253/quoted_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import Macros._

object Test {

def main(args: Array[String]): Unit = {
println(xyz"Hello World")
println(xyz"Hello ${"World"}")
}

}
6 changes: 3 additions & 3 deletions tests/run-with-compiler/quote-matcher-runtime.check
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ Result: Some(List())

Scrutinee: fs()
Pattern: fs((scala.internal.Quoted.patternHole[scala.Seq[scala.Int]]: scala.<repeated>[scala.Int]))
Result: Some(List(Expr()))
Result: Some(List(Expr((: scala.<repeated>[scala.Int]))))

Scrutinee: fs((1, 2, 3: scala.<repeated>[scala.Int]))
Pattern: fs((1, 2, 3: scala.<repeated>[scala.Int]))
Expand All @@ -202,7 +202,7 @@ Result: Some(List(Expr(1), Expr(2)))

Scrutinee: fs((1, 2, 3: scala.<repeated>[scala.Int]))
Pattern: fs((scala.internal.Quoted.patternHole[scala.Seq[scala.Int]]: scala.<repeated>[scala.Int]))
Result: Some(List(Expr(1, 2, 3)))
Result: Some(List(Expr((1, 2, 3: scala.<repeated>[scala.Int]))))

Scrutinee: f2(1, 2)
Pattern: f2(1, 2)
Expand Down Expand Up @@ -246,7 +246,7 @@ Result: Some(List(Expr("abc"), Expr("xyz")))

Scrutinee: scala.StringContext.apply(("abc", "xyz": scala.<repeated>[scala.Predef.String]))
Pattern: scala.StringContext.apply((scala.internal.Quoted.patternHole[scala.Seq[scala.Predef.String]]: scala.<repeated>[scala.Predef.String]))
Result: Some(List(Expr("abc", "xyz")))
Result: Some(List(Expr(("abc", "xyz": scala.<repeated>[scala.Predef.String]))))

Scrutinee: {
val a: scala.Int = 45
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
dlroW olleH
olleHWorld
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import scala.quoted._
import scala.quoted.matching._

import scala.tasty.Reflection

object Macros {

inline def (self: => StringContext) xyz(args: => String*): String = ${impl('self, 'args)}

private def impl(self: Expr[StringContext], args: Expr[Seq[String]])(implicit reflect: Reflection): Expr[String] = {
(self, args) match {
case ('{ StringContext(${Repeated(parts)}: _*) }, Repeated(args1)) =>
val strParts = parts.map { case Literal(str) => str.reverse }
val strArgs = args1.map { case Literal(str) => str }
StringContext(strParts: _*).s(strArgs: _*).toExpr
case _ => ???
}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import Macros._

object Test {

def main(args: Array[String]): Unit = {
println(xyz"Hello World")
println(xyz"Hello ${"World"}")
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
dlroW olleH
olleHWorld
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import scala.quoted._
import scala.quoted.matching._

import scala.tasty.Reflection

object Macros {

inline def (self: => StringContext) xyz(args: => String*): String = ${impl('self, 'args)}

private def impl(self: Expr[StringContext], args: Expr[Seq[String]])(implicit reflect: Reflection): Expr[String] = {
self match {
case '{ StringContext(${Repeated(parts)}: _*) } =>
val parts2 = parts.map(x => '{ $x.reverse }).toList.toExprOfList
'{ StringContext($parts2: _*).s($args: _*) }
case _ =>
'{ "ERROR" }
}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import Macros._

object Test {

def main(args: Array[String]): Unit = {
println(xyz"Hello World")
println(xyz"Hello ${"World"}")
}

}