Skip to content

Commit f94710c

Browse files
committed
Refactor interpreter
Extract the general purpose logic from the splice interpreter. The new interpreter class will be the basis for the macro annotation interpreter. The splice interpreter only keeps logic related with level -1 quote and type evaluation. Part of https://github.com/dotty-staging/dotty/tree/design-macro-annotations
1 parent bf808b3 commit f94710c

File tree

2 files changed

+408
-333
lines changed

2 files changed

+408
-333
lines changed
Lines changed: 395 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,395 @@
1+
package dotty.tools.dotc
2+
package transform
3+
4+
import scala.language.unsafeNulls
5+
6+
import java.io.{PrintWriter, StringWriter}
7+
import java.lang.reflect.{InvocationTargetException, Method => JLRMethod}
8+
9+
import core._
10+
import Decorators._
11+
import Flags._
12+
import Types._
13+
import Contexts._
14+
import Symbols._
15+
import Constants._
16+
import ast.Trees._
17+
import ast.TreeTypeMap
18+
import util.Spans._
19+
import SymUtils._
20+
import NameKinds._
21+
import dotty.tools.dotc.ast.tpd
22+
import typer.Implicits.SearchFailureType
23+
import SymDenotations.NoDenotation
24+
25+
import scala.collection.mutable
26+
import dotty.tools.dotc.core.Annotations._
27+
import dotty.tools.dotc.core.Names._
28+
import dotty.tools.dotc.core.StdNames._
29+
import dotty.tools.dotc.core.StagingContext._
30+
import dotty.tools.dotc.quoted._
31+
import dotty.tools.dotc.transform.TreeMapWithStages._
32+
import dotty.tools.dotc.typer.ImportInfo.withRootImports
33+
import dotty.tools.dotc.ast.TreeMapWithImplicits
34+
35+
import dotty.tools.dotc.ast.tpd
36+
import dotty.tools.dotc.ast.Trees._
37+
import dotty.tools.dotc.core.Contexts._
38+
import dotty.tools.dotc.core.Decorators._
39+
import dotty.tools.dotc.core.Flags._
40+
import dotty.tools.dotc.core.NameKinds.FlatName
41+
import dotty.tools.dotc.core.Names.{Name, TermName}
42+
import dotty.tools.dotc.core.StdNames._
43+
import dotty.tools.dotc.core.Types._
44+
import dotty.tools.dotc.core.Symbols._
45+
import dotty.tools.dotc.core.Denotations.staticRef
46+
import dotty.tools.dotc.core.{NameKinds, TypeErasure}
47+
import dotty.tools.dotc.core.Constants.Constant
48+
49+
import scala.util.control.NonFatal
50+
import dotty.tools.dotc.util.SrcPos
51+
import dotty.tools.repl.AbstractFileClassLoader
52+
53+
import scala.annotation.constructorOnly
54+
import scala.annotation.tailrec
55+
import scala.collection.mutable.ListBuffer
56+
57+
import scala.quoted.runtime.impl.QuotesImpl
58+
59+
import scala.reflect.ClassTag
60+
61+
/** List of classes of the parameters of the signature of `sym` */
62+
abstract class Interpreter(pos: SrcPos, classLoader: ClassLoader)(using Context) {
63+
import Interpreter._
64+
import tpd._
65+
66+
type Env = Map[Symbol, Object]
67+
68+
/** Returns the interpreted result of interpreting the code a call to the symbol with default arguments.
69+
* Return Some of the result or None if some error happen during the interpretation.
70+
*/
71+
final def interpret[T](tree: Tree)(implicit ct: ClassTag[T]): Option[T] =
72+
interpretTree(tree)(Map.empty) match {
73+
case obj: T => Some(obj)
74+
case obj =>
75+
// TODO upgrade to a full type tag check or something similar
76+
report.error(s"Interpreted tree returned a result of an unexpected type. Expected ${ct.runtimeClass} but was ${obj.getClass}", pos)
77+
None
78+
}
79+
80+
/** Returns the interpreted result of interpreting the code a call to the symbol with default arguments. */
81+
protected def interpretTree(tree: Tree)(implicit env: Env): Object = tree match {
82+
case Literal(Constant(value)) =>
83+
interpretLiteral(value)
84+
85+
case tree: Ident if tree.symbol.is(Inline, butNot = Method) =>
86+
tree.tpe.widenTermRefExpr match
87+
case ConstantType(c) => c.value.asInstanceOf[Object]
88+
case _ => throw new StopInterpretation(em"${tree.symbol} could not be inlined", tree.srcPos)
89+
90+
// TODO disallow interpreted method calls as arguments
91+
case Call(fn, args) =>
92+
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package))
93+
interpretNew(fn.symbol, args.flatten.map(interpretTree))
94+
else if (fn.symbol.is(Module))
95+
interpretModuleAccess(fn.symbol)
96+
else if (fn.symbol.is(Method) && fn.symbol.isStatic) {
97+
val staticMethodCall = interpretedStaticMethodCall(fn.symbol.owner, fn.symbol)
98+
staticMethodCall(interpretArgs(args, fn.symbol.info))
99+
}
100+
else if fn.symbol.isStatic then
101+
assert(args.isEmpty)
102+
interpretedStaticFieldAccess(fn.symbol)
103+
else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic)
104+
if (fn.name == nme.asInstanceOfPM)
105+
interpretModuleAccess(fn.qualifier.symbol)
106+
else {
107+
val staticMethodCall = interpretedStaticMethodCall(fn.qualifier.symbol.moduleClass, fn.symbol)
108+
staticMethodCall(interpretArgs(args, fn.symbol.info))
109+
}
110+
else if (env.contains(fn.symbol))
111+
env(fn.symbol)
112+
else if (tree.symbol.is(InlineProxy))
113+
interpretTree(tree.symbol.defTree.asInstanceOf[ValOrDefDef].rhs)
114+
else
115+
unexpectedTree(tree)
116+
117+
case closureDef((ddef @ DefDef(_, ValDefs(arg :: Nil) :: Nil, _, _))) =>
118+
(obj: AnyRef) => interpretTree(ddef.rhs)(using env.updated(arg.symbol, obj))
119+
120+
// Interpret `foo(j = x, i = y)` which it is expanded to
121+
// `val j$1 = x; val i$1 = y; foo(i = i$1, j = j$1)`
122+
case Block(stats, expr) => interpretBlock(stats, expr)
123+
case NamedArg(_, arg) => interpretTree(arg)
124+
125+
case Inlined(_, bindings, expansion) => interpretBlock(bindings, expansion)
126+
127+
case Typed(expr, _) =>
128+
interpretTree(expr)
129+
130+
case SeqLiteral(elems, _) =>
131+
interpretVarargs(elems.map(e => interpretTree(e)))
132+
133+
case _ =>
134+
unexpectedTree(tree)
135+
}
136+
137+
private def interpretArgs(argss: List[List[Tree]], fnType: Type)(using Env): List[Object] = {
138+
def interpretArgsGroup(args: List[Tree], argTypes: List[Type]): List[Object] =
139+
assert(args.size == argTypes.size)
140+
val view =
141+
for (arg, info) <- args.lazyZip(argTypes) yield
142+
info match
143+
case _: ExprType => () => interpretTree(arg) // by-name argument
144+
case _ => interpretTree(arg) // by-value argument
145+
view.toList
146+
147+
fnType.dealias match
148+
case fnType: MethodType if fnType.isErasedMethod => interpretArgs(argss, fnType.resType)
149+
case fnType: MethodType =>
150+
val argTypes = fnType.paramInfos
151+
assert(argss.head.size == argTypes.size)
152+
interpretArgsGroup(argss.head, argTypes) ::: interpretArgs(argss.tail, fnType.resType)
153+
case fnType: AppliedType if defn.isContextFunctionType(fnType) =>
154+
val argTypes :+ resType = fnType.args: @unchecked
155+
interpretArgsGroup(argss.head, argTypes) ::: interpretArgs(argss.tail, resType)
156+
case fnType: PolyType => interpretArgs(argss, fnType.resType)
157+
case fnType: ExprType => interpretArgs(argss, fnType.resType)
158+
case _ =>
159+
assert(argss.isEmpty)
160+
Nil
161+
}
162+
163+
private def interpretBlock(stats: List[Tree], expr: Tree)(implicit env: Env) = {
164+
var unexpected: Option[Object] = None
165+
val newEnv = stats.foldLeft(env)((accEnv, stat) => stat match {
166+
case stat: ValDef =>
167+
accEnv.updated(stat.symbol, interpretTree(stat.rhs)(accEnv))
168+
case stat =>
169+
if (unexpected.isEmpty)
170+
unexpected = Some(unexpectedTree(stat))
171+
accEnv
172+
})
173+
unexpected.getOrElse(interpretTree(expr)(newEnv))
174+
}
175+
176+
private def interpretLiteral(value: Any)(implicit env: Env): Object =
177+
value.asInstanceOf[Object]
178+
179+
private def interpretVarargs(args: List[Object])(implicit env: Env): Object =
180+
args.toSeq
181+
182+
private def interpretedStaticMethodCall(moduleClass: Symbol, fn: Symbol)(implicit env: Env): List[Object] => Object = {
183+
val (inst, clazz) =
184+
try
185+
if (moduleClass.name.startsWith(str.REPL_SESSION_LINE))
186+
(null, loadReplLineClass(moduleClass))
187+
else {
188+
val inst = loadModule(moduleClass)
189+
(inst, inst.getClass)
190+
}
191+
catch
192+
case MissingClassDefinedInCurrentRun(sym) if ctx.compilationUnit.isSuspendable =>
193+
if (ctx.settings.XprintSuspension.value)
194+
report.echo(i"suspension triggered by a dependency on $sym", pos)
195+
ctx.compilationUnit.suspend() // this throws a SuspendException
196+
197+
val name = fn.name.asTermName
198+
val method = getMethod(clazz, name, paramsSig(fn))
199+
(args: List[Object]) => stopIfRuntimeException(method.invoke(inst, args: _*), method)
200+
}
201+
202+
private def interpretedStaticFieldAccess(sym: Symbol)(implicit env: Env): Object = {
203+
val clazz = loadClass(sym.owner.fullName.toString)
204+
val field = clazz.getField(sym.name.toString)
205+
field.get(null)
206+
}
207+
208+
private def interpretModuleAccess(fn: Symbol)(implicit env: Env): Object =
209+
loadModule(fn.moduleClass)
210+
211+
private def interpretNew(fn: Symbol, args: => List[Object])(implicit env: Env): Object = {
212+
val clazz = loadClass(fn.owner.fullName.toString)
213+
val constr = clazz.getConstructor(paramsSig(fn): _*)
214+
constr.newInstance(args: _*).asInstanceOf[Object]
215+
}
216+
217+
private def unexpectedTree(tree: Tree)(implicit env: Env): Object =
218+
throw new StopInterpretation("Unexpected tree could not be interpreted: " + tree, tree.srcPos)
219+
220+
private def loadModule(sym: Symbol): Object =
221+
if (sym.owner.is(Package)) {
222+
// is top level object
223+
val moduleClass = loadClass(sym.fullName.toString)
224+
moduleClass.getField(str.MODULE_INSTANCE_FIELD).get(null)
225+
}
226+
else {
227+
// nested object in an object
228+
val className = {
229+
val pack = sym.topLevelClass.owner
230+
if (pack == defn.RootPackage || pack == defn.EmptyPackageClass) sym.flatName.toString
231+
else pack.showFullName + "." + sym.flatName
232+
}
233+
val clazz = loadClass(className)
234+
clazz.getConstructor().newInstance().asInstanceOf[Object]
235+
}
236+
237+
private def loadReplLineClass(moduleClass: Symbol)(implicit env: Env): Class[?] = {
238+
val lineClassloader = new AbstractFileClassLoader(ctx.settings.outputDir.value, classLoader)
239+
lineClassloader.loadClass(moduleClass.name.firstPart.toString)
240+
}
241+
242+
private def loadClass(name: String): Class[?] =
243+
try classLoader.loadClass(name)
244+
catch {
245+
case _: ClassNotFoundException if ctx.compilationUnit.isSuspendable =>
246+
if (ctx.settings.XprintSuspension.value)
247+
report.echo(i"suspension triggered by a dependency on $name", pos)
248+
ctx.compilationUnit.suspend()
249+
case MissingClassDefinedInCurrentRun(sym) if ctx.compilationUnit.isSuspendable =>
250+
if (ctx.settings.XprintSuspension.value)
251+
report.echo(i"suspension triggered by a dependency on $sym", pos)
252+
ctx.compilationUnit.suspend() // this throws a SuspendException
253+
}
254+
255+
private def getMethod(clazz: Class[?], name: Name, paramClasses: List[Class[?]]): JLRMethod =
256+
try clazz.getMethod(name.toString, paramClasses: _*)
257+
catch {
258+
case _: NoSuchMethodException =>
259+
val msg = em"Could not find method ${clazz.getCanonicalName}.$name with parameters ($paramClasses%, %)"
260+
throw new StopInterpretation(msg, pos)
261+
case MissingClassDefinedInCurrentRun(sym) if ctx.compilationUnit.isSuspendable =>
262+
if (ctx.settings.XprintSuspension.value)
263+
report.echo(i"suspension triggered by a dependency on $sym", pos)
264+
ctx.compilationUnit.suspend() // this throws a SuspendException
265+
}
266+
267+
private def stopIfRuntimeException[T](thunk: => T, method: JLRMethod): T =
268+
try thunk
269+
catch {
270+
case ex: RuntimeException =>
271+
val sw = new StringWriter()
272+
sw.write("A runtime exception occurred while executing macro expansion\n")
273+
sw.write(ex.getMessage)
274+
sw.write("\n")
275+
ex.printStackTrace(new PrintWriter(sw))
276+
sw.write("\n")
277+
throw new StopInterpretation(sw.toString, pos)
278+
case ex: InvocationTargetException =>
279+
ex.getTargetException match {
280+
case ex: scala.quoted.runtime.StopMacroExpansion =>
281+
throw ex
282+
case MissingClassDefinedInCurrentRun(sym) if ctx.compilationUnit.isSuspendable =>
283+
if (ctx.settings.XprintSuspension.value)
284+
report.echo(i"suspension triggered by a dependency on $sym", pos)
285+
ctx.compilationUnit.suspend() // this throws a SuspendException
286+
case targetException =>
287+
val sw = new StringWriter()
288+
sw.write("Exception occurred while executing macro expansion.\n")
289+
if (!ctx.settings.Ydebug.value) {
290+
val end = targetException.getStackTrace.lastIndexWhere { x =>
291+
x.getClassName == method.getDeclaringClass.getCanonicalName && x.getMethodName == method.getName
292+
}
293+
val shortStackTrace = targetException.getStackTrace.take(end + 1)
294+
targetException.setStackTrace(shortStackTrace)
295+
}
296+
targetException.printStackTrace(new PrintWriter(sw))
297+
sw.write("\n")
298+
throw new StopInterpretation(sw.toString, pos)
299+
}
300+
}
301+
302+
private object MissingClassDefinedInCurrentRun {
303+
def unapply(targetException: NoClassDefFoundError)(using Context): Option[Symbol] = {
304+
val className = targetException.getMessage
305+
if (className eq null) None
306+
else {
307+
val sym = staticRef(className.toTypeName).symbol
308+
if (sym.isDefinedInCurrentRun) Some(sym) else None
309+
}
310+
}
311+
}
312+
313+
/** List of classes of the parameters of the signature of `sym` */
314+
private def paramsSig(sym: Symbol): List[Class[?]] = {
315+
def paramClass(param: Type): Class[?] = {
316+
def arrayDepth(tpe: Type, depth: Int): (Type, Int) = tpe match {
317+
case JavaArrayType(elemType) => arrayDepth(elemType, depth + 1)
318+
case _ => (tpe, depth)
319+
}
320+
def javaArraySig(tpe: Type): String = {
321+
val (elemType, depth) = arrayDepth(tpe, 0)
322+
val sym = elemType.classSymbol
323+
val suffix =
324+
if (sym == defn.BooleanClass) "Z"
325+
else if (sym == defn.ByteClass) "B"
326+
else if (sym == defn.ShortClass) "S"
327+
else if (sym == defn.IntClass) "I"
328+
else if (sym == defn.LongClass) "J"
329+
else if (sym == defn.FloatClass) "F"
330+
else if (sym == defn.DoubleClass) "D"
331+
else if (sym == defn.CharClass) "C"
332+
else "L" + javaSig(elemType) + ";"
333+
("[" * depth) + suffix
334+
}
335+
def javaSig(tpe: Type): String = tpe match {
336+
case tpe: JavaArrayType => javaArraySig(tpe)
337+
case _ =>
338+
// Take the flatten name of the class and the full package name
339+
val pack = tpe.classSymbol.topLevelClass.owner
340+
val packageName = if (pack == defn.EmptyPackageClass) "" else s"${pack.fullName}."
341+
packageName + tpe.classSymbol.fullNameSeparated(FlatName).toString
342+
}
343+
344+
val sym = param.classSymbol
345+
if (sym == defn.BooleanClass) classOf[Boolean]
346+
else if (sym == defn.ByteClass) classOf[Byte]
347+
else if (sym == defn.CharClass) classOf[Char]
348+
else if (sym == defn.ShortClass) classOf[Short]
349+
else if (sym == defn.IntClass) classOf[Int]
350+
else if (sym == defn.LongClass) classOf[Long]
351+
else if (sym == defn.FloatClass) classOf[Float]
352+
else if (sym == defn.DoubleClass) classOf[Double]
353+
else java.lang.Class.forName(javaSig(param), false, classLoader)
354+
}
355+
def getExtraParams(tp: Type): List[Type] = tp.widenDealias match {
356+
case tp: AppliedType if defn.isContextFunctionType(tp) =>
357+
// Call context function type direct method
358+
tp.args.init.map(arg => TypeErasure.erasure(arg)) ::: getExtraParams(tp.args.last)
359+
case _ => Nil
360+
}
361+
val extraParams = getExtraParams(sym.info.finalResultType)
362+
val allParams = TypeErasure.erasure(sym.info) match {
363+
case meth: MethodType => meth.paramInfos ::: extraParams
364+
case _ => extraParams
365+
}
366+
allParams.map(paramClass)
367+
}
368+
}
369+
370+
object Interpreter:
371+
/** Exception that stops interpretation if some issue is found */
372+
class StopInterpretation(val msg: String, val pos: SrcPos) extends Exception
373+
374+
object Call {
375+
import tpd._
376+
/** Matches an expression that is either a field access or an application
377+
* It retruns a TermRef containing field accessed or a method reference and the arguments passed to it.
378+
*/
379+
def unapply(arg: Tree)(using Context): Option[(RefTree, List[List[Tree]])] =
380+
Call0.unapply(arg).map((fn, args) => (fn, args.reverse))
381+
382+
private object Call0 {
383+
def unapply(arg: Tree)(using Context): Option[(RefTree, List[List[Tree]])] = arg match {
384+
case Select(Call0(fn, args), nme.apply) if defn.isContextFunctionType(fn.tpe.widenDealias.finalResultType) =>
385+
Some((fn, args))
386+
case fn: Ident => Some((tpd.desugarIdent(fn).withSpan(fn.span), Nil))
387+
case fn: Select => Some((fn, Nil))
388+
case Apply(f @ Call0(fn, args1), args2) =>
389+
if (f.tpe.widenDealias.isErasedMethod) Some((fn, args1))
390+
else Some((fn, args2 :: args1))
391+
case TypeApply(Call0(fn, args), _) => Some((fn, args))
392+
case _ => None
393+
}
394+
}
395+
}

0 commit comments

Comments
 (0)