diff --git a/compiler/src/dotty/tools/dotc/transform/Splicer.scala b/compiler/src/dotty/tools/dotc/transform/Splicer.scala index 5510a65b24ff..5bd039fb2629 100644 --- a/compiler/src/dotty/tools/dotc/transform/Splicer.scala +++ b/compiler/src/dotty/tools/dotc/transform/Splicer.scala @@ -11,8 +11,7 @@ import dotty.tools.dotc.core.Decorators._ import dotty.tools.dotc.core.Flags._ import dotty.tools.dotc.core.NameKinds.FlatName import dotty.tools.dotc.core.Names.{Name, TermName} -import dotty.tools.dotc.core.StdNames.nme -import dotty.tools.dotc.core.StdNames.str.MODULE_INSTANCE_FIELD +import dotty.tools.dotc.core.StdNames._ import dotty.tools.dotc.core.quoted._ import dotty.tools.dotc.core.Types._ import dotty.tools.dotc.core.Symbols._ @@ -22,6 +21,7 @@ import dotty.tools.dotc.tastyreflect.ReflectionImpl import scala.util.control.NonFatal import dotty.tools.dotc.util.SourcePosition +import dotty.tools.repl.AbstractFileClassLoader import scala.reflect.ClassTag @@ -113,14 +113,22 @@ object Splicer { } protected def interpretStaticMethodCall(moduleClass: Symbol, fn: Symbol, args: => List[Object])(implicit env: Env): Object = { - val instance = loadModule(moduleClass) + val (instance, clazz) = + if (moduleClass.name.startsWith(str.REPL_SESSION_LINE)) { + (null, loadReplLineClass(moduleClass)) + } else { + val instance = loadModule(moduleClass) + (instance, instance.getClass) + } + def getDirectName(tp: Type, name: TermName): TermName = tp.widenDealias match { case tp: AppliedType if defn.isImplicitFunctionType(tp) => getDirectName(tp.args.last, NameKinds.DirectMethodName(name)) case _ => name } + val name = getDirectName(fn.info.finalResultType, fn.name.asTermName) - val method = getMethod(instance.getClass, name, paramsSig(fn)) + val method = getMethod(clazz, name, paramsSig(fn)) stopIfRuntimeException(method.invoke(instance, args: _*)) } @@ -140,7 +148,7 @@ object Splicer { if (sym.owner.is(Package)) { // is top level object val moduleClass = loadClass(sym.fullName) - moduleClass.getField(MODULE_INSTANCE_FIELD).get(null) + moduleClass.getField(str.MODULE_INSTANCE_FIELD).get(null) } else { // nested object in an object val clazz = loadClass(sym.fullNameSeparated(FlatName)) @@ -148,6 +156,11 @@ object Splicer { } } + private def loadReplLineClass(moduleClass: Symbol)(implicit env: Env): Class[_] = { + val lineClassloader = new AbstractFileClassLoader(ctx.settings.outputDir.value, classLoader) + lineClassloader.loadClass(moduleClass.name.firstPart.toString) + } + private def loadClass(name: Name): Class[_] = { try classLoader.loadClass(name.toString) catch { diff --git a/compiler/test-resources/repl/i5551 b/compiler/test-resources/repl/i5551 new file mode 100644 index 000000000000..4e5a0a186a5e --- /dev/null +++ b/compiler/test-resources/repl/i5551 @@ -0,0 +1,12 @@ +scala> import scala.quoted._ + +scala> def assertImpl(expr: Expr[Boolean]) = '{ if !(~expr) then throw new AssertionError("failed assertion")} +def assertImpl(expr: quoted.Expr[Boolean]): quoted.Expr[Unit] + +scala> inline def assert(expr: => Boolean): Unit = ~ assertImpl('(expr)) +def assert(expr: => Boolean): Unit + +scala> assert(0 == 0) + +scala> try assert(0 == 1) catch { case _: AssertionError => println("ok") } +ok