Skip to content

Commit 05303a9

Browse files
committed
Refactor interpreter
Extract the general purpose logic from the splice interpreter. The new interpreter class will be the basis for the macro annotation interpreter. The splice interpreter only keeps logic related with level -1 quote and type evaluation. Part of https://github.com/dotty-staging/dotty/tree/design-macro-annotations
1 parent bf808b3 commit 05303a9

File tree

2 files changed

+383
-333
lines changed

2 files changed

+383
-333
lines changed
Lines changed: 368 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,368 @@
1+
package dotty.tools.dotc
2+
package quoted
3+
4+
import scala.language.unsafeNulls
5+
6+
import scala.collection.mutable
7+
import scala.reflect.ClassTag
8+
9+
import java.io.{PrintWriter, StringWriter}
10+
import java.lang.reflect.{InvocationTargetException, Method => JLRMethod}
11+
12+
import dotty.tools.dotc.ast.tpd
13+
import dotty.tools.dotc.ast.TreeMapWithImplicits
14+
import dotty.tools.dotc.core.Annotations._
15+
import dotty.tools.dotc.core.Constants._
16+
import dotty.tools.dotc.core.Contexts._
17+
import dotty.tools.dotc.core.Decorators._
18+
import dotty.tools.dotc.core.Denotations.staticRef
19+
import dotty.tools.dotc.core.Flags._
20+
import dotty.tools.dotc.core.NameKinds.FlatName
21+
import dotty.tools.dotc.core.Names._
22+
import dotty.tools.dotc.core.StagingContext._
23+
import dotty.tools.dotc.core.StdNames._
24+
import dotty.tools.dotc.core.Symbols._
25+
import dotty.tools.dotc.core.TypeErasure
26+
import dotty.tools.dotc.core.Types._
27+
import dotty.tools.dotc.quoted._
28+
import dotty.tools.dotc.transform.TreeMapWithStages._
29+
import dotty.tools.dotc.typer.ImportInfo.withRootImports
30+
import dotty.tools.dotc.util.SrcPos
31+
import dotty.tools.repl.AbstractFileClassLoader
32+
33+
34+
/** List of classes of the parameters of the signature of `sym` */
35+
abstract class Interpreter(pos: SrcPos, classLoader: ClassLoader)(using Context):
36+
import Interpreter._
37+
import tpd._
38+
39+
type Env = Map[Symbol, Object]
40+
41+
/** Returns the interpreted result of interpreting the code a call to the symbol with default arguments.
42+
* Return Some of the result or None if some error happen during the interpretation.
43+
*/
44+
final def interpret[T](tree: Tree)(implicit ct: ClassTag[T]): Option[T] =
45+
interpretTree(tree)(Map.empty) match {
46+
case obj: T => Some(obj)
47+
case obj =>
48+
// TODO upgrade to a full type tag check or something similar
49+
report.error(s"Interpreted tree returned a result of an unexpected type. Expected ${ct.runtimeClass} but was ${obj.getClass}", pos)
50+
None
51+
}
52+
53+
/** Returns the interpreted result of interpreting the code a call to the symbol with default arguments. */
54+
protected def interpretTree(tree: Tree)(implicit env: Env): Object = tree match {
55+
case Literal(Constant(value)) =>
56+
interpretLiteral(value)
57+
58+
case tree: Ident if tree.symbol.is(Inline, butNot = Method) =>
59+
tree.tpe.widenTermRefExpr match
60+
case ConstantType(c) => c.value.asInstanceOf[Object]
61+
case _ => throw new StopInterpretation(em"${tree.symbol} could not be inlined", tree.srcPos)
62+
63+
// TODO disallow interpreted method calls as arguments
64+
case Call(fn, args) =>
65+
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package))
66+
interpretNew(fn.symbol, args.flatten.map(interpretTree))
67+
else if (fn.symbol.is(Module))
68+
interpretModuleAccess(fn.symbol)
69+
else if (fn.symbol.is(Method) && fn.symbol.isStatic) {
70+
val staticMethodCall = interpretedStaticMethodCall(fn.symbol.owner, fn.symbol)
71+
staticMethodCall(interpretArgs(args, fn.symbol.info))
72+
}
73+
else if fn.symbol.isStatic then
74+
assert(args.isEmpty)
75+
interpretedStaticFieldAccess(fn.symbol)
76+
else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic)
77+
if (fn.name == nme.asInstanceOfPM)
78+
interpretModuleAccess(fn.qualifier.symbol)
79+
else {
80+
val staticMethodCall = interpretedStaticMethodCall(fn.qualifier.symbol.moduleClass, fn.symbol)
81+
staticMethodCall(interpretArgs(args, fn.symbol.info))
82+
}
83+
else if (env.contains(fn.symbol))
84+
env(fn.symbol)
85+
else if (tree.symbol.is(InlineProxy))
86+
interpretTree(tree.symbol.defTree.asInstanceOf[ValOrDefDef].rhs)
87+
else
88+
unexpectedTree(tree)
89+
90+
case closureDef((ddef @ DefDef(_, ValDefs(arg :: Nil) :: Nil, _, _))) =>
91+
(obj: AnyRef) => interpretTree(ddef.rhs)(using env.updated(arg.symbol, obj))
92+
93+
// Interpret `foo(j = x, i = y)` which it is expanded to
94+
// `val j$1 = x; val i$1 = y; foo(i = i$1, j = j$1)`
95+
case Block(stats, expr) => interpretBlock(stats, expr)
96+
case NamedArg(_, arg) => interpretTree(arg)
97+
98+
case Inlined(_, bindings, expansion) => interpretBlock(bindings, expansion)
99+
100+
case Typed(expr, _) =>
101+
interpretTree(expr)
102+
103+
case SeqLiteral(elems, _) =>
104+
interpretVarargs(elems.map(e => interpretTree(e)))
105+
106+
case _ =>
107+
unexpectedTree(tree)
108+
}
109+
110+
private def interpretArgs(argss: List[List[Tree]], fnType: Type)(using Env): List[Object] = {
111+
def interpretArgsGroup(args: List[Tree], argTypes: List[Type]): List[Object] =
112+
assert(args.size == argTypes.size)
113+
val view =
114+
for (arg, info) <- args.lazyZip(argTypes) yield
115+
info match
116+
case _: ExprType => () => interpretTree(arg) // by-name argument
117+
case _ => interpretTree(arg) // by-value argument
118+
view.toList
119+
120+
fnType.dealias match
121+
case fnType: MethodType if fnType.isErasedMethod => interpretArgs(argss, fnType.resType)
122+
case fnType: MethodType =>
123+
val argTypes = fnType.paramInfos
124+
assert(argss.head.size == argTypes.size)
125+
interpretArgsGroup(argss.head, argTypes) ::: interpretArgs(argss.tail, fnType.resType)
126+
case fnType: AppliedType if defn.isContextFunctionType(fnType) =>
127+
val argTypes :+ resType = fnType.args: @unchecked
128+
interpretArgsGroup(argss.head, argTypes) ::: interpretArgs(argss.tail, resType)
129+
case fnType: PolyType => interpretArgs(argss, fnType.resType)
130+
case fnType: ExprType => interpretArgs(argss, fnType.resType)
131+
case _ =>
132+
assert(argss.isEmpty)
133+
Nil
134+
}
135+
136+
private def interpretBlock(stats: List[Tree], expr: Tree)(implicit env: Env) = {
137+
var unexpected: Option[Object] = None
138+
val newEnv = stats.foldLeft(env)((accEnv, stat) => stat match {
139+
case stat: ValDef =>
140+
accEnv.updated(stat.symbol, interpretTree(stat.rhs)(accEnv))
141+
case stat =>
142+
if (unexpected.isEmpty)
143+
unexpected = Some(unexpectedTree(stat))
144+
accEnv
145+
})
146+
unexpected.getOrElse(interpretTree(expr)(newEnv))
147+
}
148+
149+
private def interpretLiteral(value: Any)(implicit env: Env): Object =
150+
value.asInstanceOf[Object]
151+
152+
private def interpretVarargs(args: List[Object])(implicit env: Env): Object =
153+
args.toSeq
154+
155+
private def interpretedStaticMethodCall(moduleClass: Symbol, fn: Symbol)(implicit env: Env): List[Object] => Object = {
156+
val (inst, clazz) =
157+
try
158+
if (moduleClass.name.startsWith(str.REPL_SESSION_LINE))
159+
(null, loadReplLineClass(moduleClass))
160+
else {
161+
val inst = loadModule(moduleClass)
162+
(inst, inst.getClass)
163+
}
164+
catch
165+
case MissingClassDefinedInCurrentRun(sym) if ctx.compilationUnit.isSuspendable =>
166+
if (ctx.settings.XprintSuspension.value)
167+
report.echo(i"suspension triggered by a dependency on $sym", pos)
168+
ctx.compilationUnit.suspend() // this throws a SuspendException
169+
170+
val name = fn.name.asTermName
171+
val method = getMethod(clazz, name, paramsSig(fn))
172+
(args: List[Object]) => stopIfRuntimeException(method.invoke(inst, args: _*), method)
173+
}
174+
175+
private def interpretedStaticFieldAccess(sym: Symbol)(implicit env: Env): Object = {
176+
val clazz = loadClass(sym.owner.fullName.toString)
177+
val field = clazz.getField(sym.name.toString)
178+
field.get(null)
179+
}
180+
181+
private def interpretModuleAccess(fn: Symbol)(implicit env: Env): Object =
182+
loadModule(fn.moduleClass)
183+
184+
private def interpretNew(fn: Symbol, args: => List[Object])(implicit env: Env): Object = {
185+
val clazz = loadClass(fn.owner.fullName.toString)
186+
val constr = clazz.getConstructor(paramsSig(fn): _*)
187+
constr.newInstance(args: _*).asInstanceOf[Object]
188+
}
189+
190+
private def unexpectedTree(tree: Tree)(implicit env: Env): Object =
191+
throw new StopInterpretation("Unexpected tree could not be interpreted: " + tree, tree.srcPos)
192+
193+
private def loadModule(sym: Symbol): Object =
194+
if (sym.owner.is(Package)) {
195+
// is top level object
196+
val moduleClass = loadClass(sym.fullName.toString)
197+
moduleClass.getField(str.MODULE_INSTANCE_FIELD).get(null)
198+
}
199+
else {
200+
// nested object in an object
201+
val className = {
202+
val pack = sym.topLevelClass.owner
203+
if (pack == defn.RootPackage || pack == defn.EmptyPackageClass) sym.flatName.toString
204+
else pack.showFullName + "." + sym.flatName
205+
}
206+
val clazz = loadClass(className)
207+
clazz.getConstructor().newInstance().asInstanceOf[Object]
208+
}
209+
210+
private def loadReplLineClass(moduleClass: Symbol)(implicit env: Env): Class[?] = {
211+
val lineClassloader = new AbstractFileClassLoader(ctx.settings.outputDir.value, classLoader)
212+
lineClassloader.loadClass(moduleClass.name.firstPart.toString)
213+
}
214+
215+
private def loadClass(name: String): Class[?] =
216+
try classLoader.loadClass(name)
217+
catch {
218+
case _: ClassNotFoundException if ctx.compilationUnit.isSuspendable =>
219+
if (ctx.settings.XprintSuspension.value)
220+
report.echo(i"suspension triggered by a dependency on $name", pos)
221+
ctx.compilationUnit.suspend()
222+
case MissingClassDefinedInCurrentRun(sym) if ctx.compilationUnit.isSuspendable =>
223+
if (ctx.settings.XprintSuspension.value)
224+
report.echo(i"suspension triggered by a dependency on $sym", pos)
225+
ctx.compilationUnit.suspend() // this throws a SuspendException
226+
}
227+
228+
private def getMethod(clazz: Class[?], name: Name, paramClasses: List[Class[?]]): JLRMethod =
229+
try clazz.getMethod(name.toString, paramClasses: _*)
230+
catch {
231+
case _: NoSuchMethodException =>
232+
val msg = em"Could not find method ${clazz.getCanonicalName}.$name with parameters ($paramClasses%, %)"
233+
throw new StopInterpretation(msg, pos)
234+
case MissingClassDefinedInCurrentRun(sym) if ctx.compilationUnit.isSuspendable =>
235+
if (ctx.settings.XprintSuspension.value)
236+
report.echo(i"suspension triggered by a dependency on $sym", pos)
237+
ctx.compilationUnit.suspend() // this throws a SuspendException
238+
}
239+
240+
private def stopIfRuntimeException[T](thunk: => T, method: JLRMethod): T =
241+
try thunk
242+
catch {
243+
case ex: RuntimeException =>
244+
val sw = new StringWriter()
245+
sw.write("A runtime exception occurred while executing macro expansion\n")
246+
sw.write(ex.getMessage)
247+
sw.write("\n")
248+
ex.printStackTrace(new PrintWriter(sw))
249+
sw.write("\n")
250+
throw new StopInterpretation(sw.toString, pos)
251+
case ex: InvocationTargetException =>
252+
ex.getTargetException match {
253+
case ex: scala.quoted.runtime.StopMacroExpansion =>
254+
throw ex
255+
case MissingClassDefinedInCurrentRun(sym) if ctx.compilationUnit.isSuspendable =>
256+
if (ctx.settings.XprintSuspension.value)
257+
report.echo(i"suspension triggered by a dependency on $sym", pos)
258+
ctx.compilationUnit.suspend() // this throws a SuspendException
259+
case targetException =>
260+
val sw = new StringWriter()
261+
sw.write("Exception occurred while executing macro expansion.\n")
262+
if (!ctx.settings.Ydebug.value) {
263+
val end = targetException.getStackTrace.lastIndexWhere { x =>
264+
x.getClassName == method.getDeclaringClass.getCanonicalName && x.getMethodName == method.getName
265+
}
266+
val shortStackTrace = targetException.getStackTrace.take(end + 1)
267+
targetException.setStackTrace(shortStackTrace)
268+
}
269+
targetException.printStackTrace(new PrintWriter(sw))
270+
sw.write("\n")
271+
throw new StopInterpretation(sw.toString, pos)
272+
}
273+
}
274+
275+
private object MissingClassDefinedInCurrentRun {
276+
def unapply(targetException: NoClassDefFoundError)(using Context): Option[Symbol] = {
277+
val className = targetException.getMessage
278+
if (className eq null) None
279+
else {
280+
val sym = staticRef(className.toTypeName).symbol
281+
if (sym.isDefinedInCurrentRun) Some(sym) else None
282+
}
283+
}
284+
}
285+
286+
/** List of classes of the parameters of the signature of `sym` */
287+
private def paramsSig(sym: Symbol): List[Class[?]] = {
288+
def paramClass(param: Type): Class[?] = {
289+
def arrayDepth(tpe: Type, depth: Int): (Type, Int) = tpe match {
290+
case JavaArrayType(elemType) => arrayDepth(elemType, depth + 1)
291+
case _ => (tpe, depth)
292+
}
293+
def javaArraySig(tpe: Type): String = {
294+
val (elemType, depth) = arrayDepth(tpe, 0)
295+
val sym = elemType.classSymbol
296+
val suffix =
297+
if (sym == defn.BooleanClass) "Z"
298+
else if (sym == defn.ByteClass) "B"
299+
else if (sym == defn.ShortClass) "S"
300+
else if (sym == defn.IntClass) "I"
301+
else if (sym == defn.LongClass) "J"
302+
else if (sym == defn.FloatClass) "F"
303+
else if (sym == defn.DoubleClass) "D"
304+
else if (sym == defn.CharClass) "C"
305+
else "L" + javaSig(elemType) + ";"
306+
("[" * depth) + suffix
307+
}
308+
def javaSig(tpe: Type): String = tpe match {
309+
case tpe: JavaArrayType => javaArraySig(tpe)
310+
case _ =>
311+
// Take the flatten name of the class and the full package name
312+
val pack = tpe.classSymbol.topLevelClass.owner
313+
val packageName = if (pack == defn.EmptyPackageClass) "" else s"${pack.fullName}."
314+
packageName + tpe.classSymbol.fullNameSeparated(FlatName).toString
315+
}
316+
317+
val sym = param.classSymbol
318+
if (sym == defn.BooleanClass) classOf[Boolean]
319+
else if (sym == defn.ByteClass) classOf[Byte]
320+
else if (sym == defn.CharClass) classOf[Char]
321+
else if (sym == defn.ShortClass) classOf[Short]
322+
else if (sym == defn.IntClass) classOf[Int]
323+
else if (sym == defn.LongClass) classOf[Long]
324+
else if (sym == defn.FloatClass) classOf[Float]
325+
else if (sym == defn.DoubleClass) classOf[Double]
326+
else java.lang.Class.forName(javaSig(param), false, classLoader)
327+
}
328+
def getExtraParams(tp: Type): List[Type] = tp.widenDealias match {
329+
case tp: AppliedType if defn.isContextFunctionType(tp) =>
330+
// Call context function type direct method
331+
tp.args.init.map(arg => TypeErasure.erasure(arg)) ::: getExtraParams(tp.args.last)
332+
case _ => Nil
333+
}
334+
val extraParams = getExtraParams(sym.info.finalResultType)
335+
val allParams = TypeErasure.erasure(sym.info) match {
336+
case meth: MethodType => meth.paramInfos ::: extraParams
337+
case _ => extraParams
338+
}
339+
allParams.map(paramClass)
340+
}
341+
end Interpreter
342+
343+
object Interpreter:
344+
/** Exception that stops interpretation if some issue is found */
345+
class StopInterpretation(val msg: String, val pos: SrcPos) extends Exception
346+
347+
object Call:
348+
import tpd._
349+
/** Matches an expression that is either a field access or an application
350+
* It retruns a TermRef containing field accessed or a method reference and the arguments passed to it.
351+
*/
352+
def unapply(arg: Tree)(using Context): Option[(RefTree, List[List[Tree]])] =
353+
Call0.unapply(arg).map((fn, args) => (fn, args.reverse))
354+
355+
private object Call0 {
356+
def unapply(arg: Tree)(using Context): Option[(RefTree, List[List[Tree]])] = arg match {
357+
case Select(Call0(fn, args), nme.apply) if defn.isContextFunctionType(fn.tpe.widenDealias.finalResultType) =>
358+
Some((fn, args))
359+
case fn: Ident => Some((tpd.desugarIdent(fn).withSpan(fn.span), Nil))
360+
case fn: Select => Some((fn, Nil))
361+
case Apply(f @ Call0(fn, args1), args2) =>
362+
if (f.tpe.widenDealias.isErasedMethod) Some((fn, args1))
363+
else Some((fn, args2 :: args1))
364+
case TypeApply(Call0(fn, args), _) => Some((fn, args))
365+
case _ => None
366+
}
367+
}
368+
end Call

0 commit comments

Comments
 (0)