Skip to content

Commit 0015a79

Browse files
Fix #7011: check possibly side-effecting transform
The transform function in question can be overridden with a version that produces side effects. If that is the case, and if the superclass transform function is called from the overridden transform function, the overridden transform function might be executed twice with all its side effects. This commit makes sure the transform function is not called twice on the same input. An example where it can go wrong is: https://github.com/lampepfl/dotty/blob/9a4b7d39a595dba3c0baf340b0bf911844fcae69/compiler/src/dotty/tools/dotc/typer/Typer.scala#L1080 Here, the transform function performs the side effect of entering a symbol into scope. Furthermore, if the symbol already exists in scope, it emits an error. Hence, if we call this function twice on the same argument, we are guaranteed to get an error.
1 parent 963719e commit 0015a79

File tree

3 files changed

+123
-105
lines changed

3 files changed

+123
-105
lines changed

compiler/src/dotty/tools/dotc/ast/Trees.scala

Lines changed: 106 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1238,111 +1238,112 @@ object Trees {
12381238
protected def inlineContext(call: Tree)(implicit ctx: Context): Context = ctx
12391239

12401240
abstract class TreeMap(val cpy: TreeCopier = inst.cpy) { self =>
1241-
1242-
def transform(tree: Tree)(implicit ctx: Context): Tree =
1243-
if (tree.source != ctx.source && tree.source.exists)
1244-
transform(tree)(ctx.withSource(tree.source))
1245-
else {
1246-
Stats.record(s"TreeMap.transform/$getClass")
1247-
def localCtx =
1248-
if (tree.hasType && tree.symbol.exists) ctx.withOwner(tree.symbol) else ctx
1249-
1250-
if (skipTransform(tree)) tree
1251-
else tree match {
1252-
case Ident(name) =>
1253-
tree
1254-
case Select(qualifier, name) =>
1255-
cpy.Select(tree)(transform(qualifier), name)
1256-
case This(qual) =>
1257-
tree
1258-
case Super(qual, mix) =>
1259-
cpy.Super(tree)(transform(qual), mix)
1260-
case Apply(fun, args) =>
1261-
cpy.Apply(tree)(transform(fun), transform(args))
1262-
case TypeApply(fun, args) =>
1263-
cpy.TypeApply(tree)(transform(fun), transform(args))
1264-
case Literal(const) =>
1265-
tree
1266-
case New(tpt) =>
1267-
cpy.New(tree)(transform(tpt))
1268-
case Typed(expr, tpt) =>
1269-
cpy.Typed(tree)(transform(expr), transform(tpt))
1270-
case NamedArg(name, arg) =>
1271-
cpy.NamedArg(tree)(name, transform(arg))
1272-
case Assign(lhs, rhs) =>
1273-
cpy.Assign(tree)(transform(lhs), transform(rhs))
1274-
case Block(stats, expr) =>
1275-
cpy.Block(tree)(transformStats(stats), transform(expr))
1276-
case If(cond, thenp, elsep) =>
1277-
cpy.If(tree)(transform(cond), transform(thenp), transform(elsep))
1278-
case Closure(env, meth, tpt) =>
1279-
cpy.Closure(tree)(transform(env), transform(meth), transform(tpt))
1280-
case Match(selector, cases) =>
1281-
cpy.Match(tree)(transform(selector), transformSub(cases))
1282-
case CaseDef(pat, guard, body) =>
1283-
cpy.CaseDef(tree)(transform(pat), transform(guard), transform(body))
1284-
case Labeled(bind, expr) =>
1285-
cpy.Labeled(tree)(transformSub(bind), transform(expr))
1286-
case Return(expr, from) =>
1287-
cpy.Return(tree)(transform(expr), transformSub(from))
1288-
case WhileDo(cond, body) =>
1289-
cpy.WhileDo(tree)(transform(cond), transform(body))
1290-
case Try(block, cases, finalizer) =>
1291-
cpy.Try(tree)(transform(block), transformSub(cases), transform(finalizer))
1292-
case SeqLiteral(elems, elemtpt) =>
1293-
cpy.SeqLiteral(tree)(transform(elems), transform(elemtpt))
1294-
case Inlined(call, bindings, expansion) =>
1295-
cpy.Inlined(tree)(call, transformSub(bindings), transform(expansion)(inlineContext(call)))
1296-
case TypeTree() =>
1297-
tree
1298-
case SingletonTypeTree(ref) =>
1299-
cpy.SingletonTypeTree(tree)(transform(ref))
1300-
case RefinedTypeTree(tpt, refinements) =>
1301-
cpy.RefinedTypeTree(tree)(transform(tpt), transformSub(refinements))
1302-
case AppliedTypeTree(tpt, args) =>
1303-
cpy.AppliedTypeTree(tree)(transform(tpt), transform(args))
1304-
case LambdaTypeTree(tparams, body) =>
1305-
implicit val ctx = localCtx
1306-
cpy.LambdaTypeTree(tree)(transformSub(tparams), transform(body))
1307-
case MatchTypeTree(bound, selector, cases) =>
1308-
cpy.MatchTypeTree(tree)(transform(bound), transform(selector), transformSub(cases))
1309-
case ByNameTypeTree(result) =>
1310-
cpy.ByNameTypeTree(tree)(transform(result))
1311-
case TypeBoundsTree(lo, hi) =>
1312-
cpy.TypeBoundsTree(tree)(transform(lo), transform(hi))
1313-
case Bind(name, body) =>
1314-
cpy.Bind(tree)(name, transform(body))
1315-
case Alternative(trees) =>
1316-
cpy.Alternative(tree)(transform(trees))
1317-
case UnApply(fun, implicits, patterns) =>
1318-
cpy.UnApply(tree)(transform(fun), transform(implicits), transform(patterns))
1319-
case EmptyValDef =>
1320-
tree
1321-
case tree @ ValDef(name, tpt, _) =>
1322-
implicit val ctx = localCtx
1323-
val tpt1 = transform(tpt)
1324-
val rhs1 = transform(tree.rhs)
1325-
cpy.ValDef(tree)(name, tpt1, rhs1)
1326-
case tree @ DefDef(name, tparams, vparamss, tpt, _) =>
1327-
implicit val ctx = localCtx
1328-
cpy.DefDef(tree)(name, transformSub(tparams), vparamss mapConserve (transformSub(_)), transform(tpt), transform(tree.rhs))
1329-
case tree @ TypeDef(name, rhs) =>
1330-
implicit val ctx = localCtx
1331-
cpy.TypeDef(tree)(name, transform(rhs))
1332-
case tree @ Template(constr, parents, self, _) if tree.derived.isEmpty =>
1333-
cpy.Template(tree)(transformSub(constr), transform(tree.parents), Nil, transformSub(self), transformStats(tree.body))
1334-
case Import(importGiven, expr, selectors) =>
1335-
cpy.Import(tree)(importGiven, transform(expr), selectors)
1336-
case PackageDef(pid, stats) =>
1337-
cpy.PackageDef(tree)(transformSub(pid), transformStats(stats)(localCtx))
1338-
case Annotated(arg, annot) =>
1339-
cpy.Annotated(tree)(transform(arg), transform(annot))
1340-
case Thicket(trees) =>
1341-
val trees1 = transform(trees)
1342-
if (trees1 eq trees) tree else Thicket(trees1)
1343-
case _ =>
1344-
transformMoreCases(tree)
1345-
}
1241+
def transform(tree: Tree)(implicit ctxLowPrio: Context): Tree = {
1242+
implicit val ctx: Context =
1243+
if (tree.source != ctxLowPrio.source && tree.source.exists)
1244+
ctxLowPrio.withSource(tree.source)
1245+
else ctxLowPrio
1246+
1247+
Stats.record(s"TreeMap.transform/$getClass")
1248+
def localCtx =
1249+
if (tree.hasType && tree.symbol.exists) ctx.withOwner(tree.symbol) else ctx
1250+
1251+
if (skipTransform(tree)) tree
1252+
else tree match {
1253+
case Ident(name) =>
1254+
tree
1255+
case Select(qualifier, name) =>
1256+
cpy.Select(tree)(transform(qualifier), name)
1257+
case This(qual) =>
1258+
tree
1259+
case Super(qual, mix) =>
1260+
cpy.Super(tree)(transform(qual), mix)
1261+
case Apply(fun, args) =>
1262+
cpy.Apply(tree)(transform(fun), transform(args))
1263+
case TypeApply(fun, args) =>
1264+
cpy.TypeApply(tree)(transform(fun), transform(args))
1265+
case Literal(const) =>
1266+
tree
1267+
case New(tpt) =>
1268+
cpy.New(tree)(transform(tpt))
1269+
case Typed(expr, tpt) =>
1270+
cpy.Typed(tree)(transform(expr), transform(tpt))
1271+
case NamedArg(name, arg) =>
1272+
cpy.NamedArg(tree)(name, transform(arg))
1273+
case Assign(lhs, rhs) =>
1274+
cpy.Assign(tree)(transform(lhs), transform(rhs))
1275+
case Block(stats, expr) =>
1276+
cpy.Block(tree)(transformStats(stats), transform(expr))
1277+
case If(cond, thenp, elsep) =>
1278+
cpy.If(tree)(transform(cond), transform(thenp), transform(elsep))
1279+
case Closure(env, meth, tpt) =>
1280+
cpy.Closure(tree)(transform(env), transform(meth), transform(tpt))
1281+
case Match(selector, cases) =>
1282+
cpy.Match(tree)(transform(selector), transformSub(cases))
1283+
case CaseDef(pat, guard, body) =>
1284+
cpy.CaseDef(tree)(transform(pat), transform(guard), transform(body))
1285+
case Labeled(bind, expr) =>
1286+
cpy.Labeled(tree)(transformSub(bind), transform(expr))
1287+
case Return(expr, from) =>
1288+
cpy.Return(tree)(transform(expr), transformSub(from))
1289+
case WhileDo(cond, body) =>
1290+
cpy.WhileDo(tree)(transform(cond), transform(body))
1291+
case Try(block, cases, finalizer) =>
1292+
cpy.Try(tree)(transform(block), transformSub(cases), transform(finalizer))
1293+
case SeqLiteral(elems, elemtpt) =>
1294+
cpy.SeqLiteral(tree)(transform(elems), transform(elemtpt))
1295+
case Inlined(call, bindings, expansion) =>
1296+
cpy.Inlined(tree)(call, transformSub(bindings), transform(expansion)(inlineContext(call)))
1297+
case TypeTree() =>
1298+
tree
1299+
case SingletonTypeTree(ref) =>
1300+
cpy.SingletonTypeTree(tree)(transform(ref))
1301+
case RefinedTypeTree(tpt, refinements) =>
1302+
cpy.RefinedTypeTree(tree)(transform(tpt), transformSub(refinements))
1303+
case AppliedTypeTree(tpt, args) =>
1304+
cpy.AppliedTypeTree(tree)(transform(tpt), transform(args))
1305+
case LambdaTypeTree(tparams, body) =>
1306+
implicit val ctx = localCtx
1307+
cpy.LambdaTypeTree(tree)(transformSub(tparams), transform(body))
1308+
case MatchTypeTree(bound, selector, cases) =>
1309+
cpy.MatchTypeTree(tree)(transform(bound), transform(selector), transformSub(cases))
1310+
case ByNameTypeTree(result) =>
1311+
cpy.ByNameTypeTree(tree)(transform(result))
1312+
case TypeBoundsTree(lo, hi) =>
1313+
cpy.TypeBoundsTree(tree)(transform(lo), transform(hi))
1314+
case Bind(name, body) =>
1315+
cpy.Bind(tree)(name, transform(body))
1316+
case Alternative(trees) =>
1317+
cpy.Alternative(tree)(transform(trees))
1318+
case UnApply(fun, implicits, patterns) =>
1319+
cpy.UnApply(tree)(transform(fun), transform(implicits), transform(patterns))
1320+
case EmptyValDef =>
1321+
tree
1322+
case tree @ ValDef(name, tpt, _) =>
1323+
implicit val ctx = localCtx
1324+
val tpt1 = transform(tpt)
1325+
val rhs1 = transform(tree.rhs)
1326+
cpy.ValDef(tree)(name, tpt1, rhs1)
1327+
case tree @ DefDef(name, tparams, vparamss, tpt, _) =>
1328+
implicit val ctx = localCtx
1329+
cpy.DefDef(tree)(name, transformSub(tparams), vparamss mapConserve (transformSub(_)), transform(tpt), transform(tree.rhs))
1330+
case tree @ TypeDef(name, rhs) =>
1331+
implicit val ctx = localCtx
1332+
cpy.TypeDef(tree)(name, transform(rhs))
1333+
case tree @ Template(constr, parents, self, _) if tree.derived.isEmpty =>
1334+
cpy.Template(tree)(transformSub(constr), transform(tree.parents), Nil, transformSub(self), transformStats(tree.body))
1335+
case Import(importGiven, expr, selectors) =>
1336+
cpy.Import(tree)(importGiven, transform(expr), selectors)
1337+
case PackageDef(pid, stats) =>
1338+
cpy.PackageDef(tree)(transformSub(pid), transformStats(stats)(localCtx))
1339+
case Annotated(arg, annot) =>
1340+
cpy.Annotated(tree)(transform(arg), transform(annot))
1341+
case Thicket(trees) =>
1342+
val trees1 = transform(trees)
1343+
if (trees1 eq trees) tree else Thicket(trees1)
1344+
case _ =>
1345+
transformMoreCases(tree)
1346+
}
13461347
}
13471348

13481349
def transformStats(trees: List[Tree])(implicit ctx: Context): List[Tree] =

tests/pos/i7011/Macros_1.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import scala.quoted._, scala.quoted.matching._
2+
import delegate scala.quoted._
3+
4+
inline def mcr(body: => Any): Unit = ${mcrImpl('body)}
5+
6+
def mcrImpl[T](body: Expr[Any]) given (ctx: QuoteContext): Expr[Any] = {
7+
import ctx.tasty._
8+
9+
val bTree = body.unseal
10+
val under = bTree.underlyingArgument
11+
12+
val res = '{Box(${under.asInstanceOf[Term].seal})}
13+
res
14+
}
15+
16+
class Box(inner: => Any)

tests/pos/i7011/Test_2.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
def f = mcr { try () catch { case x => } }

0 commit comments

Comments
 (0)