From 0015a79f701ff29dbd98c6b6f004906ab1017d78 Mon Sep 17 00:00:00 2001 From: Anatolii Date: Thu, 8 Aug 2019 16:34:34 +0200 Subject: [PATCH] Fix #7011: check possibly side-effecting transform The transform function in question can be overridden with a version that produces side effects. If that is the case, and if the superclass transform function is called from the overridden transform function, the overridden transform function might be executed twice with all its side effects. This commit makes sure the transform function is not called twice on the same input. An example where it can go wrong is: https://github.com/lampepfl/dotty/blob/9a4b7d39a595dba3c0baf340b0bf911844fcae69/compiler/src/dotty/tools/dotc/typer/Typer.scala#L1080 Here, the transform function performs the side effect of entering a symbol into scope. Furthermore, if the symbol already exists in scope, it emits an error. Hence, if we call this function twice on the same argument, we are guaranteed to get an error. --- compiler/src/dotty/tools/dotc/ast/Trees.scala | 211 +++++++++--------- tests/pos/i7011/Macros_1.scala | 16 ++ tests/pos/i7011/Test_2.scala | 1 + 3 files changed, 123 insertions(+), 105 deletions(-) create mode 100644 tests/pos/i7011/Macros_1.scala create mode 100644 tests/pos/i7011/Test_2.scala diff --git a/compiler/src/dotty/tools/dotc/ast/Trees.scala b/compiler/src/dotty/tools/dotc/ast/Trees.scala index b5c64669c19e..1ec6b264a04b 100644 --- a/compiler/src/dotty/tools/dotc/ast/Trees.scala +++ b/compiler/src/dotty/tools/dotc/ast/Trees.scala @@ -1238,111 +1238,112 @@ object Trees { protected def inlineContext(call: Tree)(implicit ctx: Context): Context = ctx abstract class TreeMap(val cpy: TreeCopier = inst.cpy) { self => - - def transform(tree: Tree)(implicit ctx: Context): Tree = - if (tree.source != ctx.source && tree.source.exists) - transform(tree)(ctx.withSource(tree.source)) - else { - Stats.record(s"TreeMap.transform/$getClass") - def localCtx = - if (tree.hasType && tree.symbol.exists) ctx.withOwner(tree.symbol) else ctx - - if (skipTransform(tree)) tree - else tree match { - case Ident(name) => - tree - case Select(qualifier, name) => - cpy.Select(tree)(transform(qualifier), name) - case This(qual) => - tree - case Super(qual, mix) => - cpy.Super(tree)(transform(qual), mix) - case Apply(fun, args) => - cpy.Apply(tree)(transform(fun), transform(args)) - case TypeApply(fun, args) => - cpy.TypeApply(tree)(transform(fun), transform(args)) - case Literal(const) => - tree - case New(tpt) => - cpy.New(tree)(transform(tpt)) - case Typed(expr, tpt) => - cpy.Typed(tree)(transform(expr), transform(tpt)) - case NamedArg(name, arg) => - cpy.NamedArg(tree)(name, transform(arg)) - case Assign(lhs, rhs) => - cpy.Assign(tree)(transform(lhs), transform(rhs)) - case Block(stats, expr) => - cpy.Block(tree)(transformStats(stats), transform(expr)) - case If(cond, thenp, elsep) => - cpy.If(tree)(transform(cond), transform(thenp), transform(elsep)) - case Closure(env, meth, tpt) => - cpy.Closure(tree)(transform(env), transform(meth), transform(tpt)) - case Match(selector, cases) => - cpy.Match(tree)(transform(selector), transformSub(cases)) - case CaseDef(pat, guard, body) => - cpy.CaseDef(tree)(transform(pat), transform(guard), transform(body)) - case Labeled(bind, expr) => - cpy.Labeled(tree)(transformSub(bind), transform(expr)) - case Return(expr, from) => - cpy.Return(tree)(transform(expr), transformSub(from)) - case WhileDo(cond, body) => - cpy.WhileDo(tree)(transform(cond), transform(body)) - case Try(block, cases, finalizer) => - cpy.Try(tree)(transform(block), transformSub(cases), transform(finalizer)) - case SeqLiteral(elems, elemtpt) => - cpy.SeqLiteral(tree)(transform(elems), transform(elemtpt)) - case Inlined(call, bindings, expansion) => - cpy.Inlined(tree)(call, transformSub(bindings), transform(expansion)(inlineContext(call))) - case TypeTree() => - tree - case SingletonTypeTree(ref) => - cpy.SingletonTypeTree(tree)(transform(ref)) - case RefinedTypeTree(tpt, refinements) => - cpy.RefinedTypeTree(tree)(transform(tpt), transformSub(refinements)) - case AppliedTypeTree(tpt, args) => - cpy.AppliedTypeTree(tree)(transform(tpt), transform(args)) - case LambdaTypeTree(tparams, body) => - implicit val ctx = localCtx - cpy.LambdaTypeTree(tree)(transformSub(tparams), transform(body)) - case MatchTypeTree(bound, selector, cases) => - cpy.MatchTypeTree(tree)(transform(bound), transform(selector), transformSub(cases)) - case ByNameTypeTree(result) => - cpy.ByNameTypeTree(tree)(transform(result)) - case TypeBoundsTree(lo, hi) => - cpy.TypeBoundsTree(tree)(transform(lo), transform(hi)) - case Bind(name, body) => - cpy.Bind(tree)(name, transform(body)) - case Alternative(trees) => - cpy.Alternative(tree)(transform(trees)) - case UnApply(fun, implicits, patterns) => - cpy.UnApply(tree)(transform(fun), transform(implicits), transform(patterns)) - case EmptyValDef => - tree - case tree @ ValDef(name, tpt, _) => - implicit val ctx = localCtx - val tpt1 = transform(tpt) - val rhs1 = transform(tree.rhs) - cpy.ValDef(tree)(name, tpt1, rhs1) - case tree @ DefDef(name, tparams, vparamss, tpt, _) => - implicit val ctx = localCtx - cpy.DefDef(tree)(name, transformSub(tparams), vparamss mapConserve (transformSub(_)), transform(tpt), transform(tree.rhs)) - case tree @ TypeDef(name, rhs) => - implicit val ctx = localCtx - cpy.TypeDef(tree)(name, transform(rhs)) - case tree @ Template(constr, parents, self, _) if tree.derived.isEmpty => - cpy.Template(tree)(transformSub(constr), transform(tree.parents), Nil, transformSub(self), transformStats(tree.body)) - case Import(importGiven, expr, selectors) => - cpy.Import(tree)(importGiven, transform(expr), selectors) - case PackageDef(pid, stats) => - cpy.PackageDef(tree)(transformSub(pid), transformStats(stats)(localCtx)) - case Annotated(arg, annot) => - cpy.Annotated(tree)(transform(arg), transform(annot)) - case Thicket(trees) => - val trees1 = transform(trees) - if (trees1 eq trees) tree else Thicket(trees1) - case _ => - transformMoreCases(tree) - } + def transform(tree: Tree)(implicit ctxLowPrio: Context): Tree = { + implicit val ctx: Context = + if (tree.source != ctxLowPrio.source && tree.source.exists) + ctxLowPrio.withSource(tree.source) + else ctxLowPrio + + Stats.record(s"TreeMap.transform/$getClass") + def localCtx = + if (tree.hasType && tree.symbol.exists) ctx.withOwner(tree.symbol) else ctx + + if (skipTransform(tree)) tree + else tree match { + case Ident(name) => + tree + case Select(qualifier, name) => + cpy.Select(tree)(transform(qualifier), name) + case This(qual) => + tree + case Super(qual, mix) => + cpy.Super(tree)(transform(qual), mix) + case Apply(fun, args) => + cpy.Apply(tree)(transform(fun), transform(args)) + case TypeApply(fun, args) => + cpy.TypeApply(tree)(transform(fun), transform(args)) + case Literal(const) => + tree + case New(tpt) => + cpy.New(tree)(transform(tpt)) + case Typed(expr, tpt) => + cpy.Typed(tree)(transform(expr), transform(tpt)) + case NamedArg(name, arg) => + cpy.NamedArg(tree)(name, transform(arg)) + case Assign(lhs, rhs) => + cpy.Assign(tree)(transform(lhs), transform(rhs)) + case Block(stats, expr) => + cpy.Block(tree)(transformStats(stats), transform(expr)) + case If(cond, thenp, elsep) => + cpy.If(tree)(transform(cond), transform(thenp), transform(elsep)) + case Closure(env, meth, tpt) => + cpy.Closure(tree)(transform(env), transform(meth), transform(tpt)) + case Match(selector, cases) => + cpy.Match(tree)(transform(selector), transformSub(cases)) + case CaseDef(pat, guard, body) => + cpy.CaseDef(tree)(transform(pat), transform(guard), transform(body)) + case Labeled(bind, expr) => + cpy.Labeled(tree)(transformSub(bind), transform(expr)) + case Return(expr, from) => + cpy.Return(tree)(transform(expr), transformSub(from)) + case WhileDo(cond, body) => + cpy.WhileDo(tree)(transform(cond), transform(body)) + case Try(block, cases, finalizer) => + cpy.Try(tree)(transform(block), transformSub(cases), transform(finalizer)) + case SeqLiteral(elems, elemtpt) => + cpy.SeqLiteral(tree)(transform(elems), transform(elemtpt)) + case Inlined(call, bindings, expansion) => + cpy.Inlined(tree)(call, transformSub(bindings), transform(expansion)(inlineContext(call))) + case TypeTree() => + tree + case SingletonTypeTree(ref) => + cpy.SingletonTypeTree(tree)(transform(ref)) + case RefinedTypeTree(tpt, refinements) => + cpy.RefinedTypeTree(tree)(transform(tpt), transformSub(refinements)) + case AppliedTypeTree(tpt, args) => + cpy.AppliedTypeTree(tree)(transform(tpt), transform(args)) + case LambdaTypeTree(tparams, body) => + implicit val ctx = localCtx + cpy.LambdaTypeTree(tree)(transformSub(tparams), transform(body)) + case MatchTypeTree(bound, selector, cases) => + cpy.MatchTypeTree(tree)(transform(bound), transform(selector), transformSub(cases)) + case ByNameTypeTree(result) => + cpy.ByNameTypeTree(tree)(transform(result)) + case TypeBoundsTree(lo, hi) => + cpy.TypeBoundsTree(tree)(transform(lo), transform(hi)) + case Bind(name, body) => + cpy.Bind(tree)(name, transform(body)) + case Alternative(trees) => + cpy.Alternative(tree)(transform(trees)) + case UnApply(fun, implicits, patterns) => + cpy.UnApply(tree)(transform(fun), transform(implicits), transform(patterns)) + case EmptyValDef => + tree + case tree @ ValDef(name, tpt, _) => + implicit val ctx = localCtx + val tpt1 = transform(tpt) + val rhs1 = transform(tree.rhs) + cpy.ValDef(tree)(name, tpt1, rhs1) + case tree @ DefDef(name, tparams, vparamss, tpt, _) => + implicit val ctx = localCtx + cpy.DefDef(tree)(name, transformSub(tparams), vparamss mapConserve (transformSub(_)), transform(tpt), transform(tree.rhs)) + case tree @ TypeDef(name, rhs) => + implicit val ctx = localCtx + cpy.TypeDef(tree)(name, transform(rhs)) + case tree @ Template(constr, parents, self, _) if tree.derived.isEmpty => + cpy.Template(tree)(transformSub(constr), transform(tree.parents), Nil, transformSub(self), transformStats(tree.body)) + case Import(importGiven, expr, selectors) => + cpy.Import(tree)(importGiven, transform(expr), selectors) + case PackageDef(pid, stats) => + cpy.PackageDef(tree)(transformSub(pid), transformStats(stats)(localCtx)) + case Annotated(arg, annot) => + cpy.Annotated(tree)(transform(arg), transform(annot)) + case Thicket(trees) => + val trees1 = transform(trees) + if (trees1 eq trees) tree else Thicket(trees1) + case _ => + transformMoreCases(tree) + } } def transformStats(trees: List[Tree])(implicit ctx: Context): List[Tree] = diff --git a/tests/pos/i7011/Macros_1.scala b/tests/pos/i7011/Macros_1.scala new file mode 100644 index 000000000000..c601209c381d --- /dev/null +++ b/tests/pos/i7011/Macros_1.scala @@ -0,0 +1,16 @@ +import scala.quoted._, scala.quoted.matching._ +import delegate scala.quoted._ + +inline def mcr(body: => Any): Unit = ${mcrImpl('body)} + +def mcrImpl[T](body: Expr[Any]) given (ctx: QuoteContext): Expr[Any] = { + import ctx.tasty._ + + val bTree = body.unseal + val under = bTree.underlyingArgument + + val res = '{Box(${under.asInstanceOf[Term].seal})} + res +} + +class Box(inner: => Any) diff --git a/tests/pos/i7011/Test_2.scala b/tests/pos/i7011/Test_2.scala new file mode 100644 index 000000000000..e79373c07c21 --- /dev/null +++ b/tests/pos/i7011/Test_2.scala @@ -0,0 +1 @@ +def f = mcr { try () catch { case x => } }