Skip to content

Commit dc3b995

Browse files
committed
Add quoted pattern type splices runtime
1 parent c1c026d commit dc3b995

File tree

6 files changed

+188
-24
lines changed

6 files changed

+188
-24
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/KernelImpl.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,15 @@ class KernelImpl(val rootContext: core.Contexts.Context, val rootPosition: util.
3535

3636
def Context_source(self: Context): java.nio.file.Path = self.compilationUnit.source.file.jpath
3737

38+
def Context_GADT_setFreshGADTBounds(self: Context): Context =
39+
self.fresh.setFreshGADTBounds.addMode(Mode.GadtConstraintInference)
40+
41+
def Context_GADT_addToConstraint(self: Context)(syms: List[Symbol]): Boolean =
42+
self.gadt.addToConstraint(syms)
43+
44+
def Context_GADT_approximation(self: Context)(sym: Symbol, fromBelow: Boolean): Type =
45+
self.gadt.approximation(sym, fromBelow)
46+
3847
//
3948
// REPORTING
4049
//

library/src-3.x/scala/internal/Quoted.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ object Quoted {
2525
@compileTimeOnly("Illegal reference to `scala.internal.Quoted.patternBindHole`")
2626
class patternBindHole extends Annotation
2727

28+
/** A splice of a name in a quoted pattern is that marks the definition of a type splice */
29+
class patternType extends Annotation
30+
2831
/** Artifact of pickled type splices
2932
*
3033
* During quote reification a quote `'{ ... F[$t] ... }` will be transformed into

library/src-3.x/scala/internal/quoted/Matcher.scala

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ object Matcher {
4545
}
4646

4747
/** Check that all trees match with =#= and concatenate the results with && */
48-
def (scrutinees: List[Tree]) =##= (patterns: List[Tree]) given Env: Matching =
48+
def (scrutinees: List[Tree]) =##= (patterns: List[Tree]) given Context, Env: Matching =
4949
matchLists(scrutinees, patterns)(_ =#= _)
5050

5151
/** Check that the trees match and return the contents from the pattern holes.
@@ -56,7 +56,17 @@ object Matcher {
5656
* @param `the[Env]` Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`.
5757
* @return `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes.
5858
*/
59-
def (scrutinee: Tree) =#= (pattern: Tree) given Env: Matching = {
59+
def (scrutinee0: Tree) =#= (pattern0: Tree) given Context, Env: Matching = {
60+
61+
/** Normalieze the tree */
62+
def normalize(tree: Tree): Tree = tree match {
63+
case Block(Nil, expr) => normalize(expr)
64+
case Inlined(_, Nil, expr) => normalize(expr)
65+
case _ => tree
66+
}
67+
68+
val scrutinee = normalize(scrutinee0)
69+
val pattern = normalize(pattern0)
6070

6171
/** Check that both are `val` or both are `lazy val` or both are `var` **/
6272
def checkValFlags(): Boolean = {
@@ -78,14 +88,7 @@ object Matcher {
7888
def hasBindAnnotation(sym: Symbol) =
7989
sym.annots.exists { case Apply(Select(New(TypeIdent("patternBindHole")),"<init>"),List()) => true; case _ => true }
8090

81-
/** Normalieze the tree */
82-
def normalize(tree: Tree): Tree = tree match {
83-
case Block(Nil, expr) => normalize(expr)
84-
case Inlined(_, Nil, expr) => normalize(expr)
85-
case _ => tree
86-
}
87-
88-
(normalize(scrutinee), normalize(pattern)) match {
91+
(scrutinee, pattern) match {
8992

9093
// Match a scala.internal.Quoted.patternHole typed as a repeated argument and return the scrutinee tree
9194
case (IsTerm(scrutinee @ Typed(s, tpt1)), Typed(TypeApply(patternHole, tpt :: Nil), tpt2))
@@ -110,6 +113,9 @@ object Matcher {
110113
case (Typed(expr1, tpt1), Typed(expr2, tpt2)) =>
111114
expr1 =#= expr2 && tpt1 =#= tpt2
112115

116+
case (scrutinee, Typed(expr2, _)) =>
117+
scrutinee =#= expr2
118+
113119
case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || the[Env].apply((scrutinee.symbol, pattern.symbol)) =>
114120
matched
115121

@@ -142,9 +148,6 @@ object Matcher {
142148
case (While(cond1, body1), While(cond2, body2)) =>
143149
cond1 =#= cond2 && body1 =#= body2
144150

145-
case (NamedArg(name1, expr1), NamedArg(name2, expr2)) if name1 == name2 =>
146-
expr1 =#= expr2
147-
148151
case (New(tpt1), New(tpt2)) =>
149152
tpt1 =#= tpt2
150153

@@ -157,10 +160,11 @@ object Matcher {
157160
case (Repeated(elems1, _), Repeated(elems2, _)) if elems1.size == elems2.size =>
158161
elems1 =##= elems2
159162

163+
// TODO is this case required
160164
case (IsTypeTree(scrutinee @ TypeIdent(_)), IsTypeTree(pattern @ TypeIdent(_))) if scrutinee.symbol == pattern.symbol =>
161165
matched
162166

163-
case (IsInferred(scrutinee), IsInferred(pattern)) if scrutinee.tpe <:< pattern.tpe =>
167+
case (IsTypeTree(scrutinee), IsTypeTree(pattern)) if scrutinee.tpe <:< pattern.tpe =>
164168
matched
165169

166170
case (Applied(tycon1, args1), Applied(tycon2, args2)) =>
@@ -171,7 +175,7 @@ object Matcher {
171175
if (hasBindAnnotation(pattern.symbol) || hasBindTypeAnnotation(tpt2)) bindingMatch(scrutinee.symbol)
172176
else matched
173177
def rhsEnv = the[Env] + (scrutinee.symbol -> pattern.symbol)
174-
bindMatch && tpt1 =#= tpt2 && (treeOptMatches(rhs1, rhs2) given rhsEnv)
178+
bindMatch && tpt1 =#= tpt2 && (treeOptMatches(rhs1, rhs2) given (the[Context], rhsEnv))
175179

176180
case (DefDef(_, typeParams1, paramss1, tpt1, Some(rhs1)), DefDef(_, typeParams2, paramss2, tpt2, Some(rhs2))) =>
177181
val bindMatch =
@@ -227,15 +231,15 @@ object Matcher {
227231
}
228232
}
229233

230-
def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree]) given Env: Matching = {
234+
def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree]) given Context, Env: Matching = {
231235
(scrutinee, pattern) match {
232236
case (Some(x), Some(y)) => x =#= y
233237
case (None, None) => matched
234238
case _ => notMatched
235239
}
236240
}
237241

238-
def caseMatches(scrutinee: CaseDef, pattern: CaseDef) given Env: Matching = {
242+
def caseMatches(scrutinee: CaseDef, pattern: CaseDef) given Context, Env: Matching = {
239243
val (caseEnv, patternMatch) = scrutinee.pattern =%= pattern.pattern
240244
withEnv(caseEnv) {
241245
patternMatch &&
@@ -254,7 +258,7 @@ object Matcher {
254258
* @return The new environment containing the bindings defined in this pattern tuppled with
255259
* `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes.
256260
*/
257-
def (scrutinee: Pattern) =%= (pattern: Pattern) given Env: (Env, Matching) = (scrutinee, pattern) match {
261+
def (scrutinee: Pattern) =%= (pattern: Pattern) given Context, Env: (Env, Matching) = (scrutinee, pattern) match {
258262
case (Pattern.Value(v1), Pattern.Unapply(TypeApply(Select(patternHole @ Ident("patternHole"), "unapply"), List(tpt)), Nil, Nil))
259263
if patternHole.symbol.owner.fullName == "scala.runtime.quoted.Matcher$" =>
260264
(the[Env], matched(v1.seal))
@@ -264,7 +268,7 @@ object Matcher {
264268

265269
case (Pattern.Bind(name1, body1), Pattern.Bind(name2, body2)) =>
266270
val bindEnv = the[Env] + (scrutinee.symbol -> pattern.symbol)
267-
(body1 =%= body2) given bindEnv
271+
(body1 =%= body2) given (the[Context], bindEnv)
268272

269273
case (Pattern.Unapply(fun1, implicits1, patterns1), Pattern.Unapply(fun2, implicits2, patterns2)) =>
270274
val (patEnv, patternsMatch) = foldPatterns(patterns1, patterns2)
@@ -300,16 +304,33 @@ object Matcher {
300304
(the[Env], notMatched)
301305
}
302306

303-
def foldPatterns(patterns1: List[Pattern], patterns2: List[Pattern]) given Env: (Env, Matching) = {
307+
def foldPatterns(patterns1: List[Pattern], patterns2: List[Pattern]) given Context, Env: (Env, Matching) = {
304308
if (patterns1.size != patterns2.size) (the[Env], notMatched)
305309
else patterns1.zip(patterns2).foldLeft((the[Env], matched)) { (acc, x) =>
306-
val (env, res) = (x._1 =%= x._2) given acc._1
310+
val (env, res) = (x._1 =%= x._2) given (the[Context], acc._1)
307311
(env, acc._2 && res)
308312
}
309313
}
310314

315+
def isTypeBinding(tree: Tree): Boolean = tree match {
316+
case IsTypeDef(tree) =>
317+
tree.symbol.annots.exists(_.symbol.owner.fullName == "scala.internal.Quoted$.patternType")
318+
case _ => false
319+
}
320+
311321
implicit val env: Env = Set.empty
312-
(scrutineeExpr.unseal =#= patternExpr.unseal).asOptionOfTuple.asInstanceOf[Option[Tup]]
322+
val res = patternExpr.unseal.underlyingArgument match {
323+
case Block(typeBindings, pattern) if typeBindings.forall(isTypeBinding) =>
324+
implicit val ctx2 = reflection.kernel.Context_GADT_setFreshGADTBounds(rootContext)
325+
val bindingSymbols = typeBindings.map(_.symbol(ctx2))
326+
reflection.kernel.Context_GADT_addToConstraint(ctx2)(bindingSymbols)
327+
val matchings = scrutineeExpr.unseal.underlyingArgument =#= pattern
328+
val constainedTypes = bindingSymbols.map(s => reflection.kernel.Context_GADT_approximation(ctx2)(s, true))
329+
constainedTypes.foldRight(matchings)((x, acc) => matched(x.seal) && acc)
330+
case pattern =>
331+
scrutineeExpr.unseal.underlyingArgument =#= pattern
332+
}
333+
res.asOptionOfTuple.asInstanceOf[Option[Tup]]
313334
}
314335

315336
/** Result of matching a part of an expression */

library/src/scala/tasty/reflect/Kernel.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,10 @@ trait Kernel {
140140
/** Returns the source file being compiled. The path is relative to the current working directory. */
141141
def Context_source(self: Context): java.nio.file.Path
142142

143+
def Context_GADT_setFreshGADTBounds(self: Context): Context
144+
def Context_GADT_addToConstraint(self: Context)(syms: List[Symbol]): Boolean
145+
def Context_GADT_approximation(self: Context)(sym: Symbol, fromBelow: Boolean): Type
146+
143147
//
144148
// REPORTING
145149
//

tests/run-macros/quote-matcher-runtime.check

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Result: Some(List())
1616

1717
Scrutinee: 1
1818
Pattern: (1: scala.Int)
19-
Result: None
19+
Result: Some(List())
2020

2121
Scrutinee: 3
2222
Pattern: scala.internal.Quoted.patternHole[scala.Int]
@@ -714,3 +714,118 @@ Pattern: try scala.internal.Quoted.patternHole[scala.Int] finally {
714714
}
715715
Result: Some(List(Expr(1), Expr(2)))
716716

717+
Scrutinee: scala.List.apply[scala.Int]((1, 2, 3: scala.<repeated>[scala.Int])).foreach[scala.Unit](((x: scala.Int) => scala.Predef.println(x)))
718+
Pattern: {
719+
@scala.internal.Quoted.patternType type T
720+
scala.internal.Quoted.patternHole[scala.List[scala.Int]].foreach[T](scala.internal.Quoted.patternHole[scala.Function1[scala.Int, T]])
721+
}
722+
Result: Some(List(Type(scala.Unit), Expr(scala.List.apply[scala.Int]((1, 2, 3: scala.<repeated>[scala.Int]))), Expr(((x: scala.Int) => scala.Predef.println(x)))))
723+
724+
Scrutinee: scala.List.apply[scala.Int]((1, 2, 3: scala.<repeated>[scala.Int])).foreach[scala.Unit](((x: scala.Int) => scala.Predef.println(x)))
725+
Pattern: {
726+
@scala.internal.Quoted.patternType type T = scala.Unit
727+
scala.internal.Quoted.patternHole[scala.List[scala.Int]].foreach[T](scala.internal.Quoted.patternHole[scala.Function1[scala.Int, T]])
728+
}
729+
Result: Some(List(Type(scala.Unit), Expr(scala.List.apply[scala.Int]((1, 2, 3: scala.<repeated>[scala.Int]))), Expr(((x: scala.Int) => scala.Predef.println(x)))))
730+
731+
Scrutinee: scala.List.apply[scala.Int]((1, 2, 3: scala.<repeated>[scala.Int])).foreach[scala.Unit](((x: scala.Int) => scala.Predef.println(x)))
732+
Pattern: {
733+
@scala.internal.Quoted.patternType type T <: scala.Predef.String
734+
scala.internal.Quoted.patternHole[scala.List[scala.Int]].foreach[T](scala.internal.Quoted.patternHole[scala.Function1[scala.Int, T]])
735+
}
736+
Result: None
737+
738+
Scrutinee: {
739+
val a: scala.Int = 4
740+
val b: scala.Int = 4
741+
()
742+
}
743+
Pattern: {
744+
@scala.internal.Quoted.patternType type T
745+
val a: T = scala.internal.Quoted.patternHole[T]
746+
val b: T = scala.internal.Quoted.patternHole[T]
747+
()
748+
}
749+
Result: Some(List(Type(scala.Int), Expr(4), Expr(4)))
750+
751+
Scrutinee: {
752+
val a: scala.Int = 4
753+
val b: scala.Int = 5
754+
()
755+
}
756+
Pattern: {
757+
@scala.internal.Quoted.patternType type T
758+
val a: T = scala.internal.Quoted.patternHole[T]
759+
val b: T = scala.internal.Quoted.patternHole[T]
760+
()
761+
}
762+
Result: Some(List(Type(scala.Int), Expr(4), Expr(5)))
763+
764+
Scrutinee: {
765+
val a: scala.Int = 4
766+
val b: scala.Predef.String = "x"
767+
()
768+
}
769+
Pattern: {
770+
@scala.internal.Quoted.patternType type T
771+
val a: T = scala.internal.Quoted.patternHole[T]
772+
val b: T = scala.internal.Quoted.patternHole[T]
773+
()
774+
}
775+
Result: Some(List(Type(scala.Int | java.lang.String), Expr(4), Expr("x")))
776+
777+
Scrutinee: {
778+
val a: scala.Int = 4
779+
val b: scala.Predef.String = "x"
780+
()
781+
}
782+
Pattern: {
783+
@scala.internal.Quoted.patternType type T <: scala.Int
784+
val a: T = scala.internal.Quoted.patternHole[T]
785+
val b: T = scala.internal.Quoted.patternHole[T]
786+
()
787+
}
788+
Result: None
789+
790+
Scrutinee: scala.List.apply[scala.Int]((1, 2, 3: scala.<repeated>[scala.Int])).map[scala.Double, scala.collection.immutable.List[scala.Double]](((x: scala.Int) => x.toDouble./(2)))(scala.collection.immutable.List.canBuildFrom[scala.Double]).map[java.lang.String, scala.collection.immutable.List[java.lang.String]](((y: scala.Double) => y.toString()))(scala.collection.immutable.List.canBuildFrom[java.lang.String])
791+
Pattern: {
792+
@scala.internal.Quoted.patternType type T
793+
@scala.internal.Quoted.patternType type U
794+
@scala.internal.Quoted.patternType type V
795+
796+
(scala.internal.Quoted.patternHole[scala.List[T]].map[U, scala.collection.immutable.List[U]](scala.internal.Quoted.patternHole[scala.Function1[T, U]])(scala.collection.immutable.List.canBuildFrom[U]).map[V, scala.collection.immutable.List[V]](scala.internal.Quoted.patternHole[scala.Function1[U, V]])(scala.collection.immutable.List.canBuildFrom[V]): scala.collection.immutable.List[scala.Any])
797+
}
798+
Result: Some(List(Type(scala.Int), Type(scala.Double), Type(java.lang.String), Expr(scala.List.apply[scala.Int]((1, 2, 3: scala.<repeated>[scala.Int]))), Expr(((x: scala.Int) => x.toDouble./(2))), Expr(((y: scala.Double) => y.toString()))))
799+
800+
Scrutinee: ((x: scala.Int) => x)
801+
Pattern: {
802+
@scala.internal.Quoted.patternType type T
803+
804+
(scala.internal.Quoted.patternHole[scala.Function1[T, T]]: scala.Function1[scala.Nothing, scala.Any])
805+
}
806+
Result: Some(List(Type(scala.Int), Expr(((x: scala.Int) => x))))
807+
808+
Scrutinee: ((x: scala.Int) => x.toString())
809+
Pattern: {
810+
@scala.internal.Quoted.patternType type T
811+
812+
(scala.internal.Quoted.patternHole[scala.Function1[T, T]]: scala.Function1[scala.Nothing, scala.Any])
813+
}
814+
Result: None
815+
816+
Scrutinee: ((x: scala.Any) => scala.Predef.???)
817+
Pattern: {
818+
@scala.internal.Quoted.patternType type T
819+
820+
(scala.internal.Quoted.patternHole[scala.Function1[T, T]]: scala.Function1[scala.Nothing, scala.Any])
821+
}
822+
Result: Some(List(Type(scala.Nothing), Expr(((x: scala.Any) => scala.Predef.???))))
823+
824+
Scrutinee: ((x: scala.Nothing) => (1: scala.Any))
825+
Pattern: {
826+
@scala.internal.Quoted.patternType type T
827+
828+
(scala.internal.Quoted.patternHole[scala.Function1[T, T]]: scala.Function1[scala.Nothing, scala.Any])
829+
}
830+
Result: None
831+

tests/run-macros/quote-matcher-runtime/quoted_2.scala

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import Macros._
33

44
import scala.internal.quoted.Matcher._
55

6-
import scala.internal.Quoted.{patternHole, patternBindHole}
6+
import scala.internal.Quoted.{patternHole, patternBindHole, patternType}
77

88
object Test {
99

@@ -134,6 +134,18 @@ object Test {
134134
matches(try 1 finally 2, try 1 finally 2)
135135
matches(try 1 catch { case _ => 2 }, try patternHole[Int] catch { case _ => patternHole[Int] })
136136
matches(try 1 finally 2, try patternHole[Int] finally patternHole[Int])
137+
matches(List(1, 2, 3).foreach(x => println(x)), { @patternType type T; patternHole[List[Int]].foreach[T](patternHole[Int => T]) })
138+
matches(List(1, 2, 3).foreach(x => println(x)), { @patternType type T = Unit; patternHole[List[Int]].foreach[T](patternHole[Int => T]) })
139+
matches(List(1, 2, 3).foreach(x => println(x)), { @patternType type T <: String; patternHole[List[Int]].foreach[T](patternHole[Int => T]) })
140+
matches({ val a: Int = 4; val b: Int = 4 }, { @patternType type T; { val a: T = patternHole[T]; val b: T = patternHole[T] } })
141+
matches({ val a: Int = 4; val b: Int = 5 }, { @patternType type T; { val a: T = patternHole[T]; val b: T = patternHole[T] } })
142+
matches({ val a: Int = 4; val b: String = "x" }, { @patternType type T; { val a: T = patternHole[T]; val b: T = patternHole[T] } })
143+
matches({ val a: Int = 4; val b: String = "x" }, { @patternType type T <: Int; { val a: T = patternHole[T]; val b: T = patternHole[T] } })
144+
matches(List(1, 2, 3).map(x => x.toDouble / 2).map(y => y.toString), { @patternType type T; @patternType type U; @patternType type V; patternHole[List[T]].map(patternHole[T => U]).map(patternHole[U => V]) })
145+
matches((x: Int) => x, { @patternType type T; patternHole[T => T] })
146+
matches((x: Int) => x.toString, { @patternType type T; patternHole[T => T] })
147+
matches((x: Any) => ???, { @patternType type T; patternHole[T => T] })
148+
matches((x: Nothing) => (1 : Any), { @patternType type T; patternHole[T => T] })
137149

138150
}
139151
}

0 commit comments

Comments
 (0)