Skip to content

Commit 3b152c8

Browse files
committed
Add support for pattern matching on definition identifiers
Introduces `Binding[T]` which can be used to match a check is an `Expr` is a reference to some other binding defined in scope. ```scala case '{ val $x: Int = ($body: Int) } => // where x: Binding[Int] case '{ ($x: Int) => ($body: Int) } => // where x: Binding[Int] case Binding(b) => // where b: Binding[Int] ```
1 parent 8e323b1 commit 3b152c8

File tree

19 files changed

+417
-27
lines changed

19 files changed

+417
-27
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ object desugar {
139139
* def x: Int = expr
140140
* def x_=($1: <TypeTree()>): Unit = ()
141141
*/
142-
def valDef(vdef: ValDef)(implicit ctx: Context): Tree = {
142+
def valDef(vdef0: ValDef)(implicit ctx: Context): Tree = {
143+
val vdef = transformQuotedPatternName(vdef0)
143144
val ValDef(name, tpt, rhs) = vdef
144145
val mods = vdef.mods
145146
val setterNeeded =
@@ -164,6 +165,14 @@ object desugar {
164165
else vdef
165166
}
166167

168+
def transformQuotedPatternName(vdef: ValDef)(implicit ctx: Context): ValDef = {
169+
if (ctx.mode.is(Mode.QuotedPattern) && vdef.name.startsWith("$")) {
170+
val name = vdef.name.toString.substring(1).toTermName
171+
val mods = vdef.mods.withAddedAnnotation(New(ref(defn.InternalQuoted_patternBindHoleAnnot.typeRef)).withSpan(vdef.span))
172+
cpy.ValDef(vdef)(name).withMods(mods)
173+
} else vdef
174+
}
175+
167176
def makeImplicitParameters(tpts: List[Tree], contextualFlag: FlagSet = EmptyFlags, forPrimaryConstructor: Boolean = false)(implicit ctx: Context): List[ValDef] =
168177
for (tpt <- tpts) yield {
169178
val paramFlags: FlagSet = if (forPrimaryConstructor) PrivateLocalParamAccessor else Param

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,7 @@ class Definitions {
722722
def InternalQuoted_typeQuote(implicit ctx: Context): Symbol = InternalQuoted_typeQuoteR.symbol
723723
lazy val InternalQuoted_patternHoleR: TermRef = InternalQuotedModule.requiredMethodRef("patternHole")
724724
def InternalQuoted_patternHole(implicit ctx: Context): Symbol = InternalQuoted_patternHoleR.symbol
725+
lazy val InternalQuoted_patternBindHoleAnnot: ClassSymbol = InternalQuotedModule.requiredClass("patternBindHole")
725726

726727
lazy val InternalQuotedMatcherModuleRef: TermRef = ctx.requiredModuleRef("scala.internal.quoted.Matcher")
727728
def InternalQuotedMatcherModule(implicit ctx: Context): Symbol = InternalQuotedMatcherModuleRef.symbol
@@ -741,6 +742,9 @@ class Definitions {
741742
lazy val QuotedTypeModuleRef: TermRef = ctx.requiredModuleRef("scala.quoted.Type")
742743
def QuotedTypeModule(implicit ctx: Context): Symbol = QuotedTypeModuleRef.symbol
743744

745+
lazy val QuotedMatchingBindingType: TypeRef = ctx.requiredClassRef("scala.quoted.matching.Binding")
746+
def QuotedMatchingBindingClass(implicit ctx: Context): ClassSymbol = QuotedMatchingBindingType.symbol.asClass
747+
744748
def Unpickler_unpickleExpr: TermSymbol = ctx.requiredMethod("scala.runtime.quoted.Unpickler.unpickleExpr")
745749
def Unpickler_liftedExpr: TermSymbol = ctx.requiredMethod("scala.runtime.quoted.Unpickler.liftedExpr")
746750
def Unpickler_unpickleType: TermSymbol = ctx.requiredMethod("scala.runtime.quoted.Unpickler.unpickleType")

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,8 @@ object Parsers {
410410
makeParameter(name.asTermName, TypeTree()).withSpan(tree.span)
411411
case Typed(Ident(name), tpt) =>
412412
makeParameter(name.asTermName, tpt).withSpan(tree.span)
413+
case Typed(Splice(Ident(name)), tpt) =>
414+
makeParameter(("$" + name).toTermName, tpt).withSpan(tree.span)
413415
case _ =>
414416
syntaxError(s"not a legal $expected", tree.span)
415417
makeParameter(nme.ERROR, tree)

compiler/src/dotty/tools/dotc/tastyreflect/KernelImpl.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1810,6 +1810,7 @@ class KernelImpl(val rootContext: core.Contexts.Context, val rootPosition: util.
18101810
def Definitions_TupleClass(arity: Int): Symbol = defn.TupleType(arity).classSymbol.asClass
18111811

18121812
def Definitions_InternalQuoted_patternHole: Symbol = defn.InternalQuoted_patternHole
1813+
def Definitions_InternalQuoted_patternBindHoleAnnot: Symbol = defn.InternalQuoted_patternBindHoleAnnot
18131814

18141815
// Types
18151816

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1959,6 +1959,7 @@ class Typer extends Namer
19591959
}
19601960

19611961
def splitQuotePattern(quoted: Tree)(implicit ctx: Context): (Tree, List[Tree]) = {
1962+
val ctx0 = ctx
19621963
object splitter extends tpd.TreeMap {
19631964
val patBuf = new mutable.ListBuffer[Tree]
19641965
override def transform(tree: Tree)(implicit ctx: Context) = tree match {
@@ -1973,6 +1974,13 @@ class Typer extends Namer
19731974
val pat1 = if (patType eq patType1) pat else pat.withType(patType1)
19741975
patBuf += pat1
19751976
}
1977+
case vdef: ValDef =>
1978+
if (vdef.symbol.annotations.exists(_.symbol == defn.InternalQuoted_patternBindHoleAnnot)) {
1979+
val tpe = AppliedType(defn.QuotedMatchingBindingType, vdef.tpt.tpe :: Nil)
1980+
val sym = ctx0.newPatternBoundSymbol(vdef.name, tpe, vdef.span)
1981+
patBuf += Bind(sym, untpd.Ident(nme.WILDCARD).withType(tpe)).withSpan(vdef.span)
1982+
}
1983+
super.transform(tree)
19761984
case _ =>
19771985
super.transform(tree)
19781986
}

library/src-bootstrapped/scala/internal/Quoted.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package scala.internal
22

3+
import scala.annotation.Annotation
34
import scala.quoted._
45

56
object Quoted {
@@ -19,4 +20,8 @@ object Quoted {
1920
/** A splice in a quoted pattern is desugared by the compiler into a call to this method */
2021
def patternHole[T]: T =
2122
throw new Error("Internal error: this method call should have been replaced by the compiler")
23+
24+
/** A splice of a name in a quoted pattern is desugared by wrapping getting this annotation */
25+
class patternBindHole extends Annotation
26+
2227
}

library/src-bootstrapped/scala/internal/quoted/Matcher.scala

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package scala.internal.quoted
33
import scala.annotation.internal.sharable
44

55
import scala.quoted._
6+
import scala.quoted.matching.Binding
67
import scala.tasty._
78

89
object Matcher {
@@ -51,6 +52,18 @@ object Matcher {
5152
sFlags.is(Lazy) == pFlags.is(Lazy) && sFlags.is(Mutable) == pFlags.is(Mutable)
5253
}
5354

55+
def bindingMatch(sym: Symbol) =
56+
Some(Tuple1(new Binding(sym.name, sym)))
57+
58+
def hasBindingTypeAnnotation(tpt: TypeTree): Boolean = tpt match {
59+
case Annotated(tpt2, Apply(Select(New(TypeIdent("patternBindHole")), "<init>"), Nil)) => true
60+
case Annotated(tpt2, _) => hasBindingTypeAnnotation(tpt2)
61+
case _ => false
62+
}
63+
64+
def hasBindingAnnotation(sym: Symbol) =
65+
sym.annots.exists { case Apply(Select(New(TypeIdent("patternBindHole")),"<init>"),List()) => true; case _ => true }
66+
5467
def treesMatch(scrutinees: List[Tree], patterns: List[Tree]): Option[Tuple] =
5568
if (scrutinees.size != patterns.size) None
5669
else foldMatchings(scrutinees.zip(patterns).map(treeMatches): _*)
@@ -142,24 +155,30 @@ object Matcher {
142155
foldMatchings(treeMatches(tycon1, tycon2), treesMatch(args1, args2))
143156

144157
case (ValDef(_, tpt1, rhs1), ValDef(_, tpt2, rhs2)) if checkValFlags() =>
158+
val bindMatch =
159+
if (hasBindingAnnotation(pattern.symbol) || hasBindingTypeAnnotation(tpt2)) bindingMatch(scrutinee.symbol)
160+
else Some(())
145161
val returnTptMatch = treeMatches(tpt1, tpt2)
146162
val rhsEnv = env + (scrutinee.symbol -> pattern.symbol)
147163
val rhsMatchings = treeOptMatches(rhs1, rhs2)(rhsEnv)
148-
foldMatchings(returnTptMatch, rhsMatchings)
164+
foldMatchings(bindMatch, returnTptMatch, rhsMatchings)
149165

150166
case (DefDef(_, typeParams1, paramss1, tpt1, Some(rhs1)), DefDef(_, typeParams2, paramss2, tpt2, Some(rhs2))) =>
151167
val typeParmasMatch = treesMatch(typeParams1, typeParams2)
152168
val paramssMatch =
153169
if (paramss1.size != paramss2.size) None
154170
else foldMatchings(paramss1.zip(paramss2).map { (params1, params2) => treesMatch(params1, params2) }: _*)
171+
val bindMatch =
172+
if (hasBindingAnnotation(pattern.symbol)) bindingMatch(scrutinee.symbol)
173+
else Some(())
155174
val tptMatch = treeMatches(tpt1, tpt2)
156175
val rhsEnv =
157176
env + (scrutinee.symbol -> pattern.symbol) ++
158-
typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) ++
159-
paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol)
177+
typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) ++
178+
paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol)
160179
val rhsMatch = treeMatches(rhs1, rhs2)(rhsEnv)
161180

162-
foldMatchings(typeParmasMatch, paramssMatch, tptMatch, rhsMatch)
181+
foldMatchings(bindMatch, typeParmasMatch, paramssMatch, tptMatch, rhsMatch)
163182

164183
case (Lambda(_, tpt1), Lambda(_, tpt2)) =>
165184
// TODO match tpt1 with tpt2?
@@ -180,6 +199,10 @@ object Matcher {
180199
val finalizerMatch = treeOptMatches(finalizer1, finalizer2)
181200
foldMatchings(bodyMacth, casesMatch, finalizerMatch)
182201

202+
// Ignore type annotations
203+
case (Annotated(tpt, _), _) => treeMatches(tpt, pattern)
204+
case (_, Annotated(tpt, _)) => treeMatches(scrutinee, tpt)
205+
183206
// No Match
184207
case _ =>
185208
if (debug)

library/src-non-bootstrapped/scala/internal/Quoted.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package scala.internal
22

3+
import scala.annotation.Annotation
34
import scala.quoted._
45

56
object Quoted {
@@ -16,4 +17,11 @@ object Quoted {
1617
def typeQuote[T/* <: AnyKind */]: Type[T] =
1718
throw new Error("Internal error: this method call should have been replaced by the compiler")
1819

20+
/** A splice in a quoted pattern is desugared by the compiler into a call to this method */
21+
def patternHole[T]: T =
22+
throw new Error("Internal error: this method call should have been replaced by the compiler")
23+
24+
/** A splice of a name in a quoted pattern is desugared by wrapping getting this annotation */
25+
class patternBindHole extends Annotation
26+
1927
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package scala.quoted
2+
package matching
3+
4+
import scala.tasty.Reflection // TODO do not depend on reflection directly
5+
6+
/** Binding of an Expr[T] used to know if some Expr[T] is a reference to the binding
7+
*
8+
* @param name string name of this binding
9+
* @param id unique id used for equality
10+
*/
11+
class Binding[-T] private[scala](val name: String, private[Binding] val id: Object) { self =>
12+
13+
override def equals(obj: Any): Boolean = obj match {
14+
case obj: Binding[_] => obj.id == id
15+
case _ => false
16+
}
17+
18+
override def hashCode(): Int = id.hashCode()
19+
20+
}
21+
22+
object Binding {
23+
24+
def unapply[T](expr: Expr[T])(implicit reflect: Reflection): Option[Binding[T]] = {
25+
import reflect._
26+
expr.unseal match {
27+
case IsIdent(ref) =>
28+
val sym = ref.symbol
29+
Some(new Binding[T](sym.name, sym))
30+
case _ => None
31+
}
32+
}
33+
34+
}

library/src/scala/tasty/reflect/Kernel.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1469,9 +1469,12 @@ trait Kernel {
14691469

14701470
def Definitions_TupleClass(arity: Int): Symbol
14711471

1472-
/** Symbol of scala.runtime.Quoted.patternHole */
1472+
/** Symbol of scala.internal.Quoted.patternHole */
14731473
def Definitions_InternalQuoted_patternHole: Symbol
14741474

1475+
/** Symbol of scala.internal.Quoted.patternBindHole */
1476+
def Definitions_InternalQuoted_patternBindHoleAnnot: Symbol
1477+
14751478
def Definitions_UnitType: Type
14761479
def Definitions_ByteType: Type
14771480
def Definitions_ShortType: Type

tests/neg/quotedPatterns-2.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
object Test {
22
def test(x: quoted.Expr[Int]) given tasty.Reflection = x match {
3-
case '{ val a = 4; '{ a }; $y } => y // error: access to value a from wrong staging level
3+
// case '{ val a = 4; '{ a }; $y } => y // error: access to value a from wrong staging level
4+
case '{ val `$y`: Int = 2; 1 } =>
5+
y // error
6+
case '{ ((`$y`: Int) => 3); 2 } =>
7+
y // error
48
case _ =>
59
}
610
}

tests/pos/quotedPatterns.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,17 @@ object Test {
1111
case '{g($y, $z)} => '{$y * $z}
1212
case '{ ((a: Int) => 3)($y) } => y
1313
case '{ 1 + ($y: Int)} => y
14+
case '{ val a = 1 + ($y: Int); 3 } => y
1415
// currently gives an unreachable case warning
1516
// but only when used in conjunction with the others.
1617
// I believe this is because implicit arguments are not taken
1718
// into account when checking whether we have already seen an `unapply` before.
19+
case '{ val $y: Int = $z; 1 } =>
20+
val a: quoted.matching.Binding[Int] = y
21+
z
22+
case '{ (($y: Int) => 1 + y + ($z: Int))(2) } =>
23+
val a: quoted.matching.Binding[Int] = y
24+
z
1825
case _ => '{1}
1926
}
2027
}

0 commit comments

Comments
 (0)