Skip to content

Commit 3cb8b72

Browse files
committed
Fix bugs in QuoteMatcher::MatchResult::toExpr
1 parent 6764d5e commit 3cb8b72

File tree

1 file changed

+43
-54
lines changed

1 file changed

+43
-54
lines changed

compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala

Lines changed: 43 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import dotty.tools.dotc.core.Types.*
1111
import dotty.tools.dotc.core.StdNames.nme
1212
import dotty.tools.dotc.core.Symbols.*
1313
import dotty.tools.dotc.util.optional
14+
import dotty.tools.dotc.ast.TreeTypeMap
1415

1516
/** Matches a quoted tree against a quoted pattern tree.
1617
* A quoted pattern tree may have type and term holes in addition to normal terms.
@@ -319,9 +320,9 @@ class QuoteMatcher(debug: Boolean) {
319320

320321
val env = summon[Env]
321322
val capturedIds = args.map(getCapturedIdent)
322-
val capturedSymbols = capturedIds.map(_.symbol)
323323
val capturedTargs = unrollHkNestedPairsTypeTree(targs)
324-
val captureEnv = env.filter((k, v) => !capturedSymbols.contains(v) && !capturedTargs.map(_.symbol).contains(v))
324+
val capturedSymbols = Set.from(capturedIds.map(_.symbol) ++ capturedTargs.map(_.symbol))
325+
val captureEnv = env.filter((k, v) => !capturedSymbols.contains(v))
325326
withEnv(captureEnv) {
326327
scrutinee match
327328
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), capturedTargs.map(_.tpe), env)
@@ -581,18 +582,17 @@ class QuoteMatcher(debug: Boolean) {
581582
/** Return all free variables of the term defined in the pattern (i.e. defined in `Env`) */
582583
def freePatternVars(term: Tree)(using Env, Context): Set[Symbol] =
583584
val typeAccumulator = new TypeAccumulator[Set[Symbol]] {
584-
def apply(x: Set[Symbol], tp: Type): Set[Symbol] =
585-
if summon[Env].contains(tp.typeSymbol) then
586-
foldOver(x + tp.typeSymbol, tp)
587-
else
588-
foldOver(x, tp)
585+
def apply(x: Set[Symbol], tp: Type): Set[Symbol] = tp match
586+
case tp: TypeRef if summon[Env].contains(tp.typeSymbol) => foldOver(x + tp.typeSymbol, tp)
587+
case tp: TermRef if summon[Env].contains(tp.termSymbol) => foldOver(x + tp.termSymbol, tp)
588+
case _ => foldOver(x, tp)
589589
}
590590
val treeAccumulator = new TreeAccumulator[Set[Symbol]] {
591591
def apply(x: Set[Symbol], tree: Tree)(using Context): Set[Symbol] =
592-
val tvars = typeAccumulator(Set.empty, tree.tpe)
593592
tree match
594-
case tree: Ident if summon[Env].contains(tree.symbol) => foldOver(x ++ tvars + tree.symbol, tree)
595-
case _ => foldOver(x ++ tvars, tree)
593+
case tree: Ident if summon[Env].contains(tree.symbol) => foldOver(typeAccumulator(x, tree.tpe) + tree.symbol, tree)
594+
case tree: TypeTree => typeAccumulator(x, tree.tpe)
595+
case _ => foldOver(x, tree)
596596
}
597597
treeAccumulator(Set.empty, term)
598598
}
@@ -625,49 +625,38 @@ class QuoteMatcher(debug: Boolean) {
625625
case MatchResult.ClosedTree(tree) =>
626626
new ExprImpl(tree, spliceScope)
627627
case MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, typeArgs, env) =>
628-
if typeArgs.isEmpty then
629-
val names: List[TermName] = argIds.map(_.symbol.name.asTermName)
630-
val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr))
631-
val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe))
632-
val meth = newAnonFun(ctx.owner, methTpe)
633-
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
634-
val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.head).toMap
635-
val body = new TreeMap {
636-
override def transform(tree: Tree)(using Context): Tree =
637-
tree match
638-
/*
639-
* When matching a method call `f(0)` against a HOAS pattern `p(g)` where
640-
* f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold
641-
* `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion.
642-
*/
643-
case Apply(fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args.map(transform))
644-
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
645-
case tree => super.transform(tree)
646-
}.transform(tree)
647-
TreeOps(body).changeNonLocalOwners(meth)
648-
}
649-
val hoasClosure = Closure(meth, bodyFn)
650-
new ExprImpl(hoasClosure, spliceScope)
651-
else
652-
// TODO-18271: This implementation fails Typer.assertPositioned.
653-
// We want to find safe way to generate poly function
654-
val names: List[TermName] = argIds.map(_.symbol.name.asTermName)
655-
val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr))
628+
val names: List[TermName] = argIds.map(_.symbol.name.asTermName)
629+
val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr))
630+
val ptTypeVarSymbols = typeArgs.map(_.typeSymbol)
656631

632+
val methTpe = if typeArgs.isEmpty then
633+
MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe))
634+
else
657635
val typeArgs1 = PolyType.syntheticParamNames(typeArgs.length)
658636
val bounds = typeArgs map (_ => TypeBounds.empty)
659-
val fromSymbols = typeArgs.map(_.typeSymbol)
660637
val resultTypeExp = (pt: PolyType) => {
661-
val argTypes1 = paramTypes.map(_.subst(fromSymbols, pt.paramRefs))
662-
val resultType1 = mapTypeHoles(patternTpe).subst(fromSymbols, pt.paramRefs)
638+
val argTypes1 = paramTypes.map(_.subst(ptTypeVarSymbols, pt.paramRefs))
639+
val resultType1 = mapTypeHoles(patternTpe).subst(ptTypeVarSymbols, pt.paramRefs)
663640
MethodType(argTypes1, resultType1)
664641
}
665-
val methTpe = PolyType(typeArgs1)(_ => bounds, resultTypeExp)
666-
val meth = newAnonFun(ctx.owner, methTpe)
667-
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
668-
val typeArgs = lambdaArgss.head
669-
val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.tail.head).toMap
670-
val body = new TreeMap {
642+
PolyType(typeArgs1)(_ => bounds, resultTypeExp)
643+
644+
val meth = newAnonFun(ctx.owner, methTpe)
645+
646+
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
647+
val typeArgsMap = ptTypeVarSymbols.zip(lambdaArgss.head.map(_.tpe)).toMap
648+
val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.tail.head).toMap
649+
650+
val body = new TreeTypeMap(
651+
typeMap = if typeArgs.isEmpty then IdentityTypeMap
652+
else new TypeMap() {
653+
override def apply(tp: Type): Type = tp match {
654+
case tr: TypeRef if tr.prefix.eq(NoPrefix) =>
655+
env.get(tr.symbol).flatMap(typeArgsMap.get).getOrElse(tr)
656+
case tp => mapOver(tp)
657+
}
658+
},
659+
treeMap = new TreeMap {
671660
override def transform(tree: Tree)(using Context): Tree =
672661
tree match
673662
/*
@@ -678,13 +667,13 @@ class QuoteMatcher(debug: Boolean) {
678667
case Apply(fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args.map(transform))
679668
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
680669
case tree => super.transform(tree)
681-
}
682-
.transform(tree)
683-
.subst(fromSymbols, typeArgs.map(_.symbol))
684-
TreeOps(body).changeNonLocalOwners(meth)
685-
}
686-
val hoasClosure = Closure(meth, bodyFn)
687-
new ExprImpl(hoasClosure, spliceScope)
670+
}.transform
671+
).transform(tree)
672+
673+
TreeOps(body).changeNonLocalOwners(meth)
674+
}
675+
val hoasClosure = Closure(meth, bodyFn).withSpan(tree.span)
676+
new ExprImpl(hoasClosure, spliceScope)
688677

689678
private inline def notMatched[T]: optional[T] =
690679
optional.break()

0 commit comments

Comments
 (0)