@@ -244,14 +244,14 @@ class QuoteMatcher(debug: Boolean) {
244
244
if patternHole.symbol.eq(defn.QuotedRuntimePatterns_patternHole ) &&
245
245
tpt2.tpe.derivesFrom(defn.RepeatedParamClass ) =>
246
246
scrutinee match
247
- case Typed (s, tpt1) if s.tpe <:< tpt.tpe => matched(scrutinee)
247
+ case Typed (s, tpt1) if isSubTypeUnderEnv(s, tpt) => matched(scrutinee)
248
248
case _ => notMatched
249
249
250
250
/* Term hole */
251
251
// Match a scala.internal.Quoted.patternHole and return the scrutinee tree
252
252
case TypeApply (patternHole, tpt :: Nil )
253
253
if patternHole.symbol.eq(defn.QuotedRuntimePatterns_patternHole ) &&
254
- scrutinee.tpe <:< tpt.tpe =>
254
+ isSubTypeUnderEnv( scrutinee, tpt) =>
255
255
scrutinee match
256
256
case ClosedPatternTerm (scrutinee) => matched(scrutinee)
257
257
case _ => notMatched
@@ -360,7 +360,7 @@ class QuoteMatcher(debug: Boolean) {
360
360
/* Match reference */
361
361
case _ : Ident if symbolMatch(scrutinee, pattern) => matched
362
362
/* Match type */
363
- case TypeTreeTypeTest (pattern) if scrutinee.tpe <:< pattern.tpe => matched
363
+ case TypeTreeTypeTest (pattern) if isSubTypeUnderEnv( scrutinee, pattern) => matched
364
364
case _ => notMatched
365
365
366
366
/* Match application */
@@ -439,7 +439,7 @@ class QuoteMatcher(debug: Boolean) {
439
439
// TODO remove this?
440
440
case TypeTreeTypeTest (scrutinee) =>
441
441
pattern match
442
- case TypeTreeTypeTest (pattern) if scrutinee.tpe <:< pattern.tpe => matched
442
+ case TypeTreeTypeTest (pattern) if isSubTypeUnderEnv( scrutinee, pattern) => matched
443
443
case _ => notMatched
444
444
445
445
/* Match val */
@@ -476,8 +476,13 @@ class QuoteMatcher(debug: Boolean) {
476
476
case (scparams :: screst, ptparams :: ptrest) =>
477
477
(scparams, ptparams) match
478
478
case (TypeDefs (scparams), TypeDefs (ptparams)) =>
479
- matchTypeParams(scparams, ptparams)
480
- matchParamss(screst, ptrest)
479
+ scparams.foreach(tdef => println(s " tdef.rhs = ${tdef.rhs.show}" ))
480
+ if scparams.exists(tdef => tdef.rhs.isEmpty) then
481
+ notMatched
482
+
483
+ val newEnv = summon[Env ] ++ scparams.map(_.symbol).zip(ptparams.map(_.symbol))
484
+ val (resEnv, mrrest) = withEnv(newEnv)(matchParamss(screst, ptrest))
485
+ (resEnv, mrrest)
481
486
case (ValDefs (scparams), ValDefs (ptparams)) =>
482
487
val mr1 = matchLists(scparams, ptparams)(_ =?= _)
483
488
val newEnv = summon[Env ] ++ scparams.map(_.symbol).zip(ptparams.map(_.symbol))
@@ -569,20 +574,32 @@ class QuoteMatcher(debug: Boolean) {
569
574
|| summon[Env ].get(devirtualizedScrutinee).contains(pattern)
570
575
|| devirtualizedScrutinee.allOverriddenSymbols.contains(pattern)
571
576
577
+ private def isSubTypeUnderEnv (scrutinee : Tree , pattern : Tree )(using Env , Context ): Boolean =
578
+ val env = summon[Env ]
579
+ scrutinee.subst(env.keys.toList, env.values.toList).tpe <:< pattern.tpe
580
+
572
581
private object ClosedPatternTerm {
573
582
/** Matches a term that does not contain free variables defined in the pattern (i.e. not defined in `Env`) */
574
583
def unapply (term : Tree )(using Env , Context ): Option [term.type ] =
575
584
if freePatternVars(term).isEmpty then Some (term) else None
576
585
577
586
/** Return all free variables of the term defined in the pattern (i.e. defined in `Env`) */
578
587
def freePatternVars (term : Tree )(using Env , Context ): Set [Symbol ] =
579
- val accumulator = new TreeAccumulator [Set [Symbol ]] {
588
+ val typeAccumulator = new TypeAccumulator [Set [Symbol ]] {
589
+ def apply (x : Set [Symbol ], tp : Type ): Set [Symbol ] =
590
+ if summon[Env ].contains(tp.typeSymbol) then
591
+ foldOver(x + tp.typeSymbol, tp)
592
+ else
593
+ foldOver(x, tp)
594
+ }
595
+ val treeAccumulator = new TreeAccumulator [Set [Symbol ]] {
580
596
def apply (x : Set [Symbol ], tree : Tree )(using Context ): Set [Symbol ] =
597
+ val tvars = typeAccumulator(Set .empty, tree.tpe)
581
598
tree match
582
- case tree : Ident if summon[Env ].contains(tree.symbol) => foldOver(x + tree.symbol, tree)
583
- case _ => foldOver(x, tree)
599
+ case tree : Ident if summon[Env ].contains(tree.symbol) => foldOver(x ++ tvars + tree.symbol, tree)
600
+ case _ => foldOver(x ++ tvars , tree)
584
601
}
585
- accumulator.apply (Set .empty, term)
602
+ treeAccumulator (Set .empty, term)
586
603
}
587
604
588
605
enum MatchResult :
@@ -685,16 +702,6 @@ class QuoteMatcher(debug: Boolean) {
685
702
private def matchedOpen (tree : Tree , patternTpe : Type , argIds : List [Tree ], argTypes : List [Type ], typeArgs : List [Type ], env : Env )(using Context ): MatchingExprs =
686
703
Seq (MatchResult .OpenTree (tree, patternTpe, argIds, argTypes, typeArgs, env))
687
704
688
- // private def unifySyms(params1: List[Symbol], params2: List[Symbol])(using Context) =
689
- // ctx.gadtState.addToConstraint(params1)
690
- // ctx.gadtState.addToConstraint(params2)
691
- // val paramrefs1 = params1 map (ctx.gadt.tvarOrError(_))
692
- // val paramrefs2 = params2 map (ctx.gadt.tvarOrError(_))
693
- // for ((p1, p2) <- paramrefs1.zip(paramrefs2))
694
- // do
695
- // p1 <:< p2
696
- // p2 <:< p1
697
-
698
705
extension (self : MatchingExprs )
699
706
/** Concatenates the contents of two successful matchings */
700
707
def &&& (that : MatchingExprs ): MatchingExprs = self ++ that
0 commit comments