diff --git a/library/src-3.x/scala/internal/quoted/Matcher.scala b/library/src-3.x/scala/internal/quoted/Matcher.scala index 34854b093128..52104fd48fa1 100644 --- a/library/src-3.x/scala/internal/quoted/Matcher.scala +++ b/library/src-3.x/scala/internal/quoted/Matcher.scala @@ -33,6 +33,8 @@ object Matcher { def unapply[Tup <: Tuple](scrutineeExpr: Expr[_])(implicit patternExpr: Expr[_], reflection: Reflection): Option[Tup] = { import reflection.{Bind => BindPattern, _} + type Env = Set[(Symbol, Symbol)] + // TODO improve performance /** Check that the trees match and return the contents from the pattern holes. @@ -40,10 +42,10 @@ object Matcher { * * @param scrutinee The tree beeing matched * @param pattern The pattern tree that the scrutinee should match. Contains `patternHole` holes. - * @param env Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`. + * @param `the[Env]` Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`. * @return `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes. */ - def treeMatches(scrutinee: Tree, pattern: Tree)(implicit env: Set[(Symbol, Symbol)]): Option[Tuple] = { + def treeMatches(scrutinee: Tree, pattern: Tree) given Env: Option[Tuple] = { /** Check that both are `val` or both are `lazy val` or both are `var` **/ def checkValFlags(): Boolean = { @@ -101,7 +103,7 @@ object Matcher { case (Typed(expr1, tpt1), Typed(expr2, tpt2)) => foldMatchings(treeMatches(expr1, expr2), treeMatches(tpt1, tpt2)) - case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || env((scrutinee.symbol, pattern.symbol)) => + case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || the[Env].apply((scrutinee.symbol, pattern.symbol)) => Some(()) case (Select(qual1, _), Select(qual2, _)) if scrutinee.symbol == pattern.symbol => @@ -160,8 +162,8 @@ object Matcher { if (hasBindAnnotation(pattern.symbol) || hasBindTypeAnnotation(tpt2)) bindingMatch(scrutinee.symbol) else Some(()) val returnTptMatch = treeMatches(tpt1, tpt2) - val rhsEnv = env + (scrutinee.symbol -> pattern.symbol) - val rhsMatchings = treeOptMatches(rhs1, rhs2)(rhsEnv) + val rhsEnv = the[Env] + (scrutinee.symbol -> pattern.symbol) + val rhsMatchings = treeOptMatches(rhs1, rhs2) given rhsEnv foldMatchings(bindMatch, returnTptMatch, rhsMatchings) case (DefDef(_, typeParams1, paramss1, tpt1, Some(rhs1)), DefDef(_, typeParams2, paramss2, tpt2, Some(rhs2))) => @@ -174,10 +176,10 @@ object Matcher { else Some(()) val tptMatch = treeMatches(tpt1, tpt2) val rhsEnv = - env + (scrutinee.symbol -> pattern.symbol) ++ + the[Env] + (scrutinee.symbol -> pattern.symbol) ++ typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) ++ paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol) - val rhsMatch = treeMatches(rhs1, rhs2)(rhsEnv) + val rhsMatch = treeMatches(rhs1, rhs2) given rhsEnv foldMatchings(bindMatch, typeParmasMatch, paramssMatch, tptMatch, rhsMatch) @@ -227,7 +229,7 @@ object Matcher { } } - def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree])(implicit env: Set[(Symbol, Symbol)]): Option[Tuple] = { + def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree]) given Env: Option[Tuple] = { (scrutinee, pattern) match { case (Some(x), Some(y)) => treeMatches(x, y) case (None, None) => Some(()) @@ -235,11 +237,15 @@ object Matcher { } } - def caseMatches(scrutinee: CaseDef, pattern: CaseDef)(implicit env: Set[(Symbol, Symbol)]): Option[Tuple] = { + def caseMatches(scrutinee: CaseDef, pattern: CaseDef) given Env: Option[Tuple] = { val (caseEnv, patternMatch) = patternMatches(scrutinee.pattern, pattern.pattern) - val guardMatch = treeOptMatches(scrutinee.guard, pattern.guard)(caseEnv) - val rhsMatch = treeMatches(scrutinee.rhs, pattern.rhs)(caseEnv) - foldMatchings(patternMatch, guardMatch, rhsMatch) + + { + implied for Env = caseEnv + val guardMatch = treeOptMatches(scrutinee.guard, pattern.guard) + val rhsMatch = treeMatches(scrutinee.rhs, pattern.rhs) + foldMatchings(patternMatch, guardMatch, rhsMatch) + } } /** Check that the pattern trees match and return the contents from the pattern holes. @@ -248,21 +254,21 @@ object Matcher { * * @param scrutinee The pattern tree beeing matched * @param pattern The pattern tree that the scrutinee should match. Contains `patternHole` holes. - * @param env Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`. + * @param `the[Env]` Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`. * @return The new environment containing the bindings defined in this pattern tuppled with * `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes. */ - def patternMatches(scrutinee: Pattern, pattern: Pattern)(implicit env: Set[(Symbol, Symbol)]): (Set[(Symbol, Symbol)], Option[Tuple]) = (scrutinee, pattern) match { + def patternMatches(scrutinee: Pattern, pattern: Pattern) given Env: (Env, Option[Tuple]) = (scrutinee, pattern) match { case (Pattern.Value(v1), Pattern.Unapply(TypeApply(Select(patternHole @ Ident("patternHole"), "unapply"), List(tpt)), Nil, Nil)) if patternHole.symbol.owner.fullName == "scala.runtime.quoted.Matcher$" => - (env, Some(Tuple1(v1.seal))) + (the[Env], Some(Tuple1(v1.seal))) case (Pattern.Value(v1), Pattern.Value(v2)) => - (env, treeMatches(v1, v2)) + (the[Env], treeMatches(v1, v2)) case (Pattern.Bind(name1, body1), Pattern.Bind(name2, body2)) => - val bindEnv = env + (scrutinee.symbol -> pattern.symbol) - patternMatches(body1, body2)(bindEnv) + val bindEnv = the[Env] + (scrutinee.symbol -> pattern.symbol) + patternMatches(body1, body2) given bindEnv case (Pattern.Unapply(fun1, implicits1, patterns1), Pattern.Unapply(fun2, implicits2, patterns2)) => val funMatch = treeMatches(fun1, fun2) @@ -276,10 +282,10 @@ object Matcher { foldPatterns(patterns1, patterns2) case (Pattern.TypeTest(tpt1), Pattern.TypeTest(tpt2)) => - (env, treeMatches(tpt1, tpt2)) + (the[Env], treeMatches(tpt1, tpt2)) case (Pattern.WildcardPattern(), Pattern.WildcardPattern()) => - (env, Some(())) + (the[Env], Some(())) case _ => if (debug) @@ -299,18 +305,19 @@ object Matcher { | | |""".stripMargin) - (env, None) + (the[Env], None) } - def foldPatterns(patterns1: List[Pattern], patterns2: List[Pattern])(implicit env: Set[(Symbol, Symbol)]): (Set[(Symbol, Symbol)], Option[Tuple]) = { - if (patterns1.size != patterns2.size) (env, None) - else patterns1.zip(patterns2).foldLeft((env, Option[Tuple](()))) { (acc, x) => - val (env, res) = patternMatches(x._1, x._2)(acc._1) + def foldPatterns(patterns1: List[Pattern], patterns2: List[Pattern]) given Env: (Env, Option[Tuple]) = { + if (patterns1.size != patterns2.size) (the[Env], None) + else patterns1.zip(patterns2).foldLeft((the[Env], Option[Tuple](()))) { (acc, x) => + val (env, res) = patternMatches(x._1, x._2) given acc._1 (env, foldMatchings(acc._2, res)) } } - treeMatches(scrutineeExpr.unseal, patternExpr.unseal)(Set.empty).asInstanceOf[Option[Tup]] + implied for Env = Set.empty + treeMatches(scrutineeExpr.unseal, patternExpr.unseal).asInstanceOf[Option[Tup]] } /** Joins the mattchings into a single matching. If any matching is `None` the result is `None`.