Skip to content

Commit c50ef24

Browse files
committed
Use tree checker for macro expanded trees
Trees are only checked if -Xcheck-macros is enabled. Fixes: - Add missing positions to {ValDef,Bind}.apply - Inline by-name ascribed param - Unbound type variables after implicit search
1 parent 2920a4f commit c50ef24

File tree

8 files changed

+87
-20
lines changed

8 files changed

+87
-20
lines changed

compiler/src/dotty/tools/dotc/inlines/Inliner.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ class Inliner(val call: tpd.Tree)(using Context):
227227
val binding = {
228228
var newArg = arg.changeOwner(ctx.owner, boundSym)
229229
if bindingFlags.is(Inline) && argIsBottom then
230-
newArg = Typed(newArg, TypeTree(formal)) // type ascribe RHS to avoid type errors in expansion. See i8612.scala
230+
newArg = Typed(newArg, TypeTree(formal.widenExpr)) // type ascribe RHS to avoid type errors in expansion. See i8612.scala
231231
if isByName then DefDef(boundSym, newArg)
232232
else ValDef(boundSym, newArg)
233233
}.withSpan(boundSym.span)
@@ -816,6 +816,7 @@ class Inliner(val call: tpd.Tree)(using Context):
816816
&& StagingContext.level == 0
817817
&& !hasInliningErrors =>
818818
val expanded = expandMacro(res.args.head, tree.srcPos)
819+
transform.TreeChecker.checkMacroGeneratedTree(res, expanded)
819820
typedExpr(expanded) // Inline calls and constant fold code generated by the macro
820821
case res =>
821822
specializeEq(inlineIfNeeded(res, pt, locked))

compiler/src/dotty/tools/dotc/transform/MacroAnnotations.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,9 @@ class MacroAnnotations(thisPhase: DenotTransformer):
8282
case (prefixed, newTree :: suffixed) =>
8383
allTrees ++= prefixed
8484
insertedAfter = suffixed :: insertedAfter
85-
prefixed.foreach(checkAndEnter(_, tree.symbol, annot))
86-
suffixed.foreach(checkAndEnter(_, tree.symbol, annot))
85+
prefixed.foreach(checkAndEnter(_, tree, annot))
86+
suffixed.foreach(checkAndEnter(_, tree, annot))
87+
transform.TreeChecker.checkMacroGeneratedTree(tree, newTree)
8788
newTree
8889
case (Nil, Nil) =>
8990
report.error(i"Unexpected `Nil` returned by `(${annot.tree}).transform(..)` during macro expansion", annot.tree.srcPos)
@@ -119,8 +120,10 @@ class MacroAnnotations(thisPhase: DenotTransformer):
119120
annotInstance.transform(using quotes)(tree.asInstanceOf[quotes.reflect.Definition])
120121

121122
/** Check that this tree can be added by the macro annotation and enter it if needed */
122-
private def checkAndEnter(newTree: Tree, annotated: Symbol, annot: Annotation)(using Context) =
123+
private def checkAndEnter(newTree: Tree, annotatedTree: Tree, annot: Annotation)(using Context) =
124+
transform.TreeChecker.checkMacroGeneratedTree(annotatedTree, newTree)
123125
val sym = newTree.symbol
126+
val annotated = annotatedTree.symbol
124127
if sym.isClass then
125128
report.error(i"macro annotation returning a `class` is not yet supported. $annot tried to add $sym", annot.tree)
126129
else if sym.isType then

compiler/src/dotty/tools/dotc/transform/TreeChecker.scala

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -105,18 +105,6 @@ class TreeChecker extends Phase with SymTransformer {
105105
else if (ctx.phase.prev.isCheckable)
106106
check(ctx.base.allPhases.toIndexedSeq, ctx)
107107

108-
private def previousPhases(phases: List[Phase])(using Context): List[Phase] = phases match {
109-
case (phase: MegaPhase) :: phases1 =>
110-
val subPhases = phase.miniPhases
111-
val previousSubPhases = previousPhases(subPhases.toList)
112-
if (previousSubPhases.length == subPhases.length) previousSubPhases ::: previousPhases(phases1)
113-
else previousSubPhases
114-
case phase :: phases1 if phase ne ctx.phase =>
115-
phase :: previousPhases(phases1)
116-
case _ =>
117-
Nil
118-
}
119-
120108
def check(phasesToRun: Seq[Phase], ctx: Context): Tree = {
121109
val fusedPhase = ctx.phase.prevMega(using ctx)
122110
report.echo(s"checking ${ctx.compilationUnit} after phase ${fusedPhase}")(using ctx)
@@ -219,7 +207,7 @@ object TreeChecker {
219207
class Checker(phasesToCheck: Seq[Phase]) extends ReTyper with Checking {
220208
import ast.tpd._
221209

222-
private val nowDefinedSyms = util.HashSet[Symbol]()
210+
protected val nowDefinedSyms = util.HashSet[Symbol]()
223211
private val patBoundSyms = util.HashSet[Symbol]()
224212
private val everDefinedSyms = MutableSymbolMap[untpd.Tree]()
225213

@@ -724,4 +712,52 @@ object TreeChecker {
724712

725713
override def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): tree.type = tree
726714
}
715+
716+
/** Tree checker that can be applied to a local tree. */
717+
class LocalChecker(phasesToCheck: Seq[Phase]) extends Checker(phasesToCheck: Seq[Phase]):
718+
override def assertDefined(tree: untpd.Tree)(using Context): Unit =
719+
// Only check definitions nested in the local tree
720+
if nowDefinedSyms.contains(tree.symbol.maybeOwner) then
721+
super.assertDefined(tree)
722+
723+
def checkMacroGeneratedTree(original: tpd.Tree, expansion: tpd.Tree)(using Context): Unit =
724+
if ctx.settings.XcheckMacros.value then
725+
val checkingCtx = ctx
726+
.fresh
727+
.addMode(Mode.ImplicitsEnabled)
728+
.setReporter(new ThrowingReporter(ctx.reporter))
729+
val phases = ctx.base.allPhases.toList
730+
val treeChecker = new LocalChecker(previousPhases(phases))
731+
732+
try treeChecker.typed(expansion)(using checkingCtx)
733+
catch
734+
case err: java.lang.AssertionError =>
735+
report.error(
736+
s"""Malformed tree was found while expanding macro with -Xcheck-macros.
737+
|The tree does not conform to the compiler's tree invariants.
738+
|
739+
|Macro was:
740+
|${scala.quoted.runtime.impl.QuotesImpl.showDecompiledTree(original)}
741+
|
742+
|The macro returned:
743+
|${scala.quoted.runtime.impl.QuotesImpl.showDecompiledTree(expansion)}
744+
|
745+
|Error:
746+
|${err.getMessage}
747+
|
748+
|""",
749+
original
750+
)
751+
752+
private[TreeChecker] def previousPhases(phases: List[Phase])(using Context): List[Phase] = phases match {
753+
case (phase: MegaPhase) :: phases1 =>
754+
val subPhases = phase.miniPhases
755+
val previousSubPhases = previousPhases(subPhases.toList)
756+
if (previousSubPhases.length == subPhases.length) previousSubPhases ::: previousPhases(phases1)
757+
else previousSubPhases
758+
case phase :: phases1 if phase ne ctx.phase =>
759+
phase :: previousPhases(phases1)
760+
case _ =>
761+
Nil
762+
}
727763
}

compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
298298

299299
object ValDef extends ValDefModule:
300300
def apply(symbol: Symbol, rhs: Option[Term]): ValDef =
301-
tpd.ValDef(symbol.asTerm, xCheckMacroedOwners(xCheckMacroValidExpr(rhs), symbol).getOrElse(tpd.EmptyTree))
301+
withDefaultPos(tpd.ValDef(symbol.asTerm, xCheckMacroedOwners(xCheckMacroValidExpr(rhs), symbol).getOrElse(tpd.EmptyTree)))
302302
def copy(original: Tree)(name: String, tpt: TypeTree, rhs: Option[Term]): ValDef =
303303
tpd.cpy.ValDef(original)(name.toTermName, tpt, xCheckMacroedOwners(xCheckMacroValidExpr(rhs), original.symbol).getOrElse(tpd.EmptyTree))
304304
def unapply(vdef: ValDef): (String, TypeTree, Option[Term]) =
@@ -1474,7 +1474,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
14741474

14751475
object Bind extends BindModule:
14761476
def apply(sym: Symbol, pattern: Tree): Bind =
1477-
tpd.Bind(sym, pattern)
1477+
withDefaultPos(tpd.Bind(sym, pattern))
14781478
def copy(original: Tree)(name: String, pattern: Tree): Bind =
14791479
withDefaultPos(tpd.cpy.Bind(original)(name.toTermName, pattern))
14801480
def unapply(pattern: Bind): (String, Tree) =
@@ -2395,7 +2395,12 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
23952395

23962396
object Implicits extends ImplicitsModule:
23972397
def search(tpe: TypeRepr): ImplicitSearchResult =
2398-
ctx.typer.inferImplicitArg(tpe, Position.ofMacroExpansion.span)
2398+
import tpd.TreeOps
2399+
val implicitTree = ctx.typer.inferImplicitArg(tpe, Position.ofMacroExpansion.span)
2400+
// Make sure that we do not have any uninstantiated type variables.
2401+
// See tests/pos-macros/exprSummonWithTypeVar with -Xcheck-macros.
2402+
implicitTree.foreachSubTree(tree => dotc.typer.Inferencing.fullyDefinedType(tree.tpe, "", tree))
2403+
implicitTree
23992404
end Implicits
24002405

24012406
type ImplicitSearchResult = Tree
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import scala.compiletime.{erasedValue, summonFrom}
2+
3+
import scala.quoted._
4+
5+
inline given summonAfterTypeMatch[T]: Any =
6+
${ summonAfterTypeMatchExpr[T] }
7+
8+
private def summonAfterTypeMatchExpr[T: Type](using Quotes): Expr[Any] =
9+
Expr.summon[Foo[T]].get
10+
11+
trait Foo[T]
12+
13+
given IntFoo[T <: Int]: Foo[T] = ???
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
def test: Unit = summonAfterTypeMatch[Int]
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import scala.quoted.*
2+
3+
inline def f[T](inline code: =>T): Any =
4+
${ create[T]('{ () => code }) }
5+
6+
def create[T: Type](code: Expr[() => T])(using Quotes): Expr[Any] =
7+
'{ identity($code) }
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
def test: Unit = f[Unit](???)

0 commit comments

Comments
 (0)