Skip to content

Use contextual types in quote matcher #6380

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 33 additions & 26 deletions library/src-3.x/scala/internal/quoted/Matcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,19 @@ object Matcher {
def unapply[Tup <: Tuple](scrutineeExpr: Expr[_])(implicit patternExpr: Expr[_], reflection: Reflection): Option[Tup] = {
import reflection.{Bind => BindPattern, _}

type Env = Set[(Symbol, Symbol)]

// TODO improve performance

/** Check that the trees match and return the contents from the pattern holes.
* Return None if the trees do not match otherwise return Some of a tuple containing all the contents in the holes.
*
* @param scrutinee The tree beeing matched
* @param pattern The pattern tree that the scrutinee should match. Contains `patternHole` holes.
* @param env Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`.
* @param `the[Env]` Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`.
* @return `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes.
*/
def treeMatches(scrutinee: Tree, pattern: Tree)(implicit env: Set[(Symbol, Symbol)]): Option[Tuple] = {
def treeMatches(scrutinee: Tree, pattern: Tree) given Env: Option[Tuple] = {

/** Check that both are `val` or both are `lazy val` or both are `var` **/
def checkValFlags(): Boolean = {
Expand Down Expand Up @@ -101,7 +103,7 @@ object Matcher {
case (Typed(expr1, tpt1), Typed(expr2, tpt2)) =>
foldMatchings(treeMatches(expr1, expr2), treeMatches(tpt1, tpt2))

case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || env((scrutinee.symbol, pattern.symbol)) =>
case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || the[Env].apply((scrutinee.symbol, pattern.symbol)) =>
Some(())

case (Select(qual1, _), Select(qual2, _)) if scrutinee.symbol == pattern.symbol =>
Expand Down Expand Up @@ -160,8 +162,8 @@ object Matcher {
if (hasBindAnnotation(pattern.symbol) || hasBindTypeAnnotation(tpt2)) bindingMatch(scrutinee.symbol)
else Some(())
val returnTptMatch = treeMatches(tpt1, tpt2)
val rhsEnv = env + (scrutinee.symbol -> pattern.symbol)
val rhsMatchings = treeOptMatches(rhs1, rhs2)(rhsEnv)
val rhsEnv = the[Env] + (scrutinee.symbol -> pattern.symbol)
val rhsMatchings = treeOptMatches(rhs1, rhs2) given rhsEnv
foldMatchings(bindMatch, returnTptMatch, rhsMatchings)

case (DefDef(_, typeParams1, paramss1, tpt1, Some(rhs1)), DefDef(_, typeParams2, paramss2, tpt2, Some(rhs2))) =>
Expand All @@ -174,10 +176,10 @@ object Matcher {
else Some(())
val tptMatch = treeMatches(tpt1, tpt2)
val rhsEnv =
env + (scrutinee.symbol -> pattern.symbol) ++
the[Env] + (scrutinee.symbol -> pattern.symbol) ++
typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) ++
paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol)
val rhsMatch = treeMatches(rhs1, rhs2)(rhsEnv)
val rhsMatch = treeMatches(rhs1, rhs2) given rhsEnv

foldMatchings(bindMatch, typeParmasMatch, paramssMatch, tptMatch, rhsMatch)

Expand Down Expand Up @@ -227,19 +229,23 @@ object Matcher {
}
}

def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree])(implicit env: Set[(Symbol, Symbol)]): Option[Tuple] = {
def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree]) given Env: Option[Tuple] = {
(scrutinee, pattern) match {
case (Some(x), Some(y)) => treeMatches(x, y)
case (None, None) => Some(())
case _ => None
}
}

def caseMatches(scrutinee: CaseDef, pattern: CaseDef)(implicit env: Set[(Symbol, Symbol)]): Option[Tuple] = {
def caseMatches(scrutinee: CaseDef, pattern: CaseDef) given Env: Option[Tuple] = {
val (caseEnv, patternMatch) = patternMatches(scrutinee.pattern, pattern.pattern)
val guardMatch = treeOptMatches(scrutinee.guard, pattern.guard)(caseEnv)
val rhsMatch = treeMatches(scrutinee.rhs, pattern.rhs)(caseEnv)
foldMatchings(patternMatch, guardMatch, rhsMatch)

{
implied for Env = caseEnv
val guardMatch = treeOptMatches(scrutinee.guard, pattern.guard)
val rhsMatch = treeMatches(scrutinee.rhs, pattern.rhs)
foldMatchings(patternMatch, guardMatch, rhsMatch)
}
}

/** Check that the pattern trees match and return the contents from the pattern holes.
Expand All @@ -248,21 +254,21 @@ object Matcher {
*
* @param scrutinee The pattern tree beeing matched
* @param pattern The pattern tree that the scrutinee should match. Contains `patternHole` holes.
* @param env Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`.
* @param `the[Env]` Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`.
* @return The new environment containing the bindings defined in this pattern tuppled with
* `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes.
*/
def patternMatches(scrutinee: Pattern, pattern: Pattern)(implicit env: Set[(Symbol, Symbol)]): (Set[(Symbol, Symbol)], Option[Tuple]) = (scrutinee, pattern) match {
def patternMatches(scrutinee: Pattern, pattern: Pattern) given Env: (Env, Option[Tuple]) = (scrutinee, pattern) match {
case (Pattern.Value(v1), Pattern.Unapply(TypeApply(Select(patternHole @ Ident("patternHole"), "unapply"), List(tpt)), Nil, Nil))
if patternHole.symbol.owner.fullName == "scala.runtime.quoted.Matcher$" =>
(env, Some(Tuple1(v1.seal)))
(the[Env], Some(Tuple1(v1.seal)))

case (Pattern.Value(v1), Pattern.Value(v2)) =>
(env, treeMatches(v1, v2))
(the[Env], treeMatches(v1, v2))

case (Pattern.Bind(name1, body1), Pattern.Bind(name2, body2)) =>
val bindEnv = env + (scrutinee.symbol -> pattern.symbol)
patternMatches(body1, body2)(bindEnv)
val bindEnv = the[Env] + (scrutinee.symbol -> pattern.symbol)
patternMatches(body1, body2) given bindEnv

case (Pattern.Unapply(fun1, implicits1, patterns1), Pattern.Unapply(fun2, implicits2, patterns2)) =>
val funMatch = treeMatches(fun1, fun2)
Expand All @@ -276,10 +282,10 @@ object Matcher {
foldPatterns(patterns1, patterns2)

case (Pattern.TypeTest(tpt1), Pattern.TypeTest(tpt2)) =>
(env, treeMatches(tpt1, tpt2))
(the[Env], treeMatches(tpt1, tpt2))

case (Pattern.WildcardPattern(), Pattern.WildcardPattern()) =>
(env, Some(()))
(the[Env], Some(()))

case _ =>
if (debug)
Expand All @@ -299,18 +305,19 @@ object Matcher {
|
|
|""".stripMargin)
(env, None)
(the[Env], None)
}

def foldPatterns(patterns1: List[Pattern], patterns2: List[Pattern])(implicit env: Set[(Symbol, Symbol)]): (Set[(Symbol, Symbol)], Option[Tuple]) = {
if (patterns1.size != patterns2.size) (env, None)
else patterns1.zip(patterns2).foldLeft((env, Option[Tuple](()))) { (acc, x) =>
val (env, res) = patternMatches(x._1, x._2)(acc._1)
def foldPatterns(patterns1: List[Pattern], patterns2: List[Pattern]) given Env: (Env, Option[Tuple]) = {
if (patterns1.size != patterns2.size) (the[Env], None)
else patterns1.zip(patterns2).foldLeft((the[Env], Option[Tuple](()))) { (acc, x) =>
val (env, res) = patternMatches(x._1, x._2) given acc._1
(env, foldMatchings(acc._2, res))
}
}

treeMatches(scrutineeExpr.unseal, patternExpr.unseal)(Set.empty).asInstanceOf[Option[Tup]]
implied for Env = Set.empty
treeMatches(scrutineeExpr.unseal, patternExpr.unseal).asInstanceOf[Option[Tup]]
}

/** Joins the mattchings into a single matching. If any matching is `None` the result is `None`.
Expand Down