Skip to content

Commit 4f1ad71

Browse files
committed
Harden REPL in presence of values that fail to initialize
The right hand side of value definitions in the REPL are computed in the static initializer for the wrapper object created for that input line (e.g. rs$line$1). If any of these definitions throws an exception, the wrapper class will fail to initialize, and further attempts to use the class will throw NoClassDefFoundError. In this commit, we avoid all reflective access on a wrapper class once we notice that it failed to initialize, and mark that wrapper object as invalid in the REPL state. We discard all input from the failed wrapper (which may have been multi-line containing many statements and definitions); any types, terms, aliases, or imports defined there will not override any existing with the same name, and will not be accessible in subsequent runs. Fixes #4416 Fixes #14473
1 parent 3d06d94 commit 4f1ad71

File tree

4 files changed

+139
-21
lines changed

4 files changed

+139
-21
lines changed

compiler/src/dotty/tools/repl/Rendering.scala

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,15 @@ private[repl] class Rendering(parentClassLoader: Option[ClassLoader] = None) {
129129
infoDiagnostic(d.symbol.showUser, d)
130130

131131
/** Render value definition result */
132-
def renderVal(d: Denotation)(using Context): Option[Diagnostic] =
132+
def renderVal(d: Denotation)(using Context): Either[InvocationTargetException, Option[Diagnostic]] =
133133
val dcl = d.symbol.showUser
134134
def msg(s: String) = infoDiagnostic(s, d)
135135
try
136-
if (d.symbol.is(Flags.Lazy)) Some(msg(dcl))
137-
else valueOf(d.symbol).map(value => msg(s"$dcl = $value"))
138-
catch case e: InvocationTargetException => Some(msg(renderError(e, d)))
136+
Right(
137+
if d.symbol.is(Flags.Lazy) then Some(msg(dcl))
138+
else valueOf(d.symbol).map(value => msg(s"$dcl = $value"))
139+
)
140+
catch case e: InvocationTargetException => Left(e)
139141
end renderVal
140142

141143
/** Force module initialization in the absence of members. */
@@ -144,10 +146,10 @@ private[repl] class Rendering(parentClassLoader: Option[ClassLoader] = None) {
144146
val objectName = sym.fullName.encode.toString
145147
Class.forName(objectName, true, classLoader())
146148
Nil
147-
try load() catch case e: ExceptionInInitializerError => List(infoDiagnostic(renderError(e, sym.denot), sym.denot))
149+
try load() catch case e: ExceptionInInitializerError => List(renderError(e, sym.denot))
148150

149151
/** Render the stack trace of the underlying exception. */
150-
private def renderError(ite: InvocationTargetException | ExceptionInInitializerError, d: Denotation)(using Context): String =
152+
def renderError(ite: InvocationTargetException | ExceptionInInitializerError, d: Denotation)(using Context): Diagnostic =
151153
import dotty.tools.dotc.util.StackTraceOps._
152154
val cause = ite.getCause match
153155
case e: ExceptionInInitializerError => e.getCause
@@ -159,7 +161,7 @@ private[repl] class Rendering(parentClassLoader: Option[ClassLoader] = None) {
159161
ste.getClassName.startsWith(REPL_WRAPPER_NAME_PREFIX) // d.symbol.owner.name.show is simple name
160162
&& (ste.getMethodName == nme.STATIC_CONSTRUCTOR.show || ste.getMethodName == nme.CONSTRUCTOR.show)
161163

162-
cause.formatStackTracePrefix(!isWrapperInitialization(_))
164+
infoDiagnostic(cause.formatStackTracePrefix(!isWrapperInitialization(_)), d)
163165
end renderError
164166

165167
private def infoDiagnostic(msg: String, d: Denotation)(using Context): Diagnostic =

compiler/src/dotty/tools/repl/ReplCompiler.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class ReplCompiler extends Compiler {
6161
val rootCtx = super.rootContext.fresh
6262
.setOwner(defn.EmptyPackageClass)
6363
.withRootImports
64-
(1 to state.objectIndex).foldLeft(rootCtx)((ctx, id) =>
64+
(state.validObjectIndexes).foldLeft(rootCtx)((ctx, id) =>
6565
importPreviousRun(id)(using ctx))
6666
}
6767
}

compiler/src/dotty/tools/repl/ReplDriver.scala

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import dotty.tools.runner.ScalaClassLoader.*
3535
import org.jline.reader._
3636

3737
import scala.annotation.tailrec
38+
import scala.collection.mutable
3839
import scala.collection.JavaConverters._
3940
import scala.util.Using
4041

@@ -55,12 +56,15 @@ import scala.util.Using
5556
* @param objectIndex the index of the next wrapper
5657
* @param valIndex the index of next value binding for free expressions
5758
* @param imports a map from object index to the list of user defined imports
59+
* @param invalidObjectIndexes the set of object indexes that failed to initialize
5860
* @param context the latest compiler context
5961
*/
6062
case class State(objectIndex: Int,
6163
valIndex: Int,
6264
imports: Map[Int, List[tpd.Import]],
63-
context: Context)
65+
invalidObjectIndexes: Set[Int],
66+
context: Context):
67+
def validObjectIndexes = (1 to objectIndex).filterNot(invalidObjectIndexes.contains(_))
6468

6569
/** Main REPL instance, orchestrating input, compilation and presentation */
6670
class ReplDriver(settings: Array[String],
@@ -94,7 +98,7 @@ class ReplDriver(settings: Array[String],
9498
}
9599

96100
/** the initial, empty state of the REPL session */
97-
final def initialState: State = State(0, 0, Map.empty, rootCtx)
101+
final def initialState: State = State(0, 0, Map.empty, Set.empty, rootCtx)
98102

99103
/** Reset state of repl to the initial state
100104
*
@@ -237,7 +241,7 @@ class ReplDriver(settings: Array[String],
237241
completions.map(_.label).distinct.map(makeCandidate)
238242
}
239243
.getOrElse(Nil)
240-
end completions
244+
end completions
241245

242246
private def interpret(res: ParseResult)(implicit state: State): State = {
243247
res match {
@@ -353,14 +357,33 @@ class ReplDriver(settings: Array[String],
353357
val typeAliases =
354358
info.bounds.hi.typeMembers.filter(_.symbol.info.isTypeAlias)
355359

356-
val formattedMembers =
357-
typeAliases.map(rendering.renderTypeAlias) ++
358-
defs.map(rendering.renderMethod) ++
359-
vals.flatMap(rendering.renderVal)
360-
361-
val diagnostics = if formattedMembers.isEmpty then rendering.forceModule(symbol) else formattedMembers
362-
363-
(state.copy(valIndex = state.valIndex - vals.count(resAndUnit)), diagnostics)
360+
// The wrapper object may fail to initialize if the rhs of a ValDef throws.
361+
// In that case, don't attempt to render any subsequent vals, and mark this
362+
// wrapper object index as invalid.
363+
var failedInit = false
364+
val renderedVals =
365+
val buf = mutable.ListBuffer[Diagnostic]()
366+
for d <- vals do if !failedInit then rendering.renderVal(d) match
367+
case Right(Some(v)) =>
368+
buf += v
369+
case Left(e) =>
370+
buf += rendering.renderError(e, d)
371+
failedInit = true
372+
case _ =>
373+
buf.toList
374+
375+
if failedInit then
376+
// We limit the returned diagnostics here to `renderedVals`, which will contain the rendered error
377+
// for the val which failed to initialize. Since any other defs, aliases, imports, etc. from this
378+
// input line will be inaccessible, we avoid rendering those so as not to confuse the user.
379+
(state.copy(invalidObjectIndexes = state.invalidObjectIndexes + state.objectIndex), renderedVals)
380+
else
381+
val formattedMembers =
382+
typeAliases.map(rendering.renderTypeAlias)
383+
++ defs.map(rendering.renderMethod)
384+
++ renderedVals
385+
val diagnostics = if formattedMembers.isEmpty then rendering.forceModule(symbol) else formattedMembers
386+
(state.copy(valIndex = state.valIndex - vals.count(resAndUnit)), diagnostics)
364387
}
365388
else (state, Seq.empty)
366389

@@ -378,8 +401,10 @@ class ReplDriver(settings: Array[String],
378401
tree.symbol.info.memberClasses
379402
.find(_.symbol.name == newestWrapper.moduleClassName)
380403
.map { wrapperModule =>
381-
val formattedTypeDefs = typeDefs(wrapperModule.symbol)
382404
val (newState, formattedMembers) = extractAndFormatMembers(wrapperModule.symbol)
405+
val formattedTypeDefs = // don't render type defs if wrapper initialization failed
406+
if newState.invalidObjectIndexes.contains(state.objectIndex) then Seq.empty
407+
else typeDefs(wrapperModule.symbol)
383408
val highlighted = (formattedTypeDefs ++ formattedMembers)
384409
.map(d => new Diagnostic(d.msg.mapMsg(SyntaxHighlighting.highlight), d.pos, d.level))
385410
(newState, highlighted)
@@ -420,7 +445,7 @@ class ReplDriver(settings: Array[String],
420445

421446
case Imports =>
422447
for {
423-
objectIndex <- 1 to state.objectIndex
448+
objectIndex <- state.validObjectIndexes
424449
imp <- state.imports.getOrElse(objectIndex, Nil)
425450
} out.println(imp.show(using state.context))
426451
state

compiler/test/dotty/tools/repl/ReplCompilerTests.scala

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,97 @@ class ReplCompilerTests extends ReplTest:
243243
assertEquals(List("// defined class C"), lines())
244244
}
245245

246+
def assertNotFoundError(id: String): Unit =
247+
val lines = storedOutput().linesIterator
248+
assert(lines.next().startsWith("-- [E006] Not Found Error:"))
249+
assert(lines.drop(2).next().trim().endsWith(s"Not found: $id"))
250+
251+
@Test def i4416 = initially {
252+
val state = run("val x = 1 / 0")
253+
val all = lines()
254+
assertEquals(2, all.length)
255+
assert(all.head.startsWith("java.lang.ArithmeticException:"))
256+
state
257+
} andThen {
258+
val state = run("def foo = x")
259+
assertNotFoundError("x")
260+
state
261+
} andThen {
262+
run("x")
263+
assertNotFoundError("x")
264+
}
265+
266+
@Test def i4416b = initially {
267+
val state = run("val a = 1234")
268+
val _ = storedOutput() // discard output
269+
state
270+
} andThen {
271+
val state = run("val a = 1; val x = ???; val y = x")
272+
val all = lines()
273+
assertEquals(3, all.length)
274+
assertEquals("scala.NotImplementedError: an implementation is missing", all.head)
275+
state
276+
} andThen {
277+
val state = run("x")
278+
assertNotFoundError("x")
279+
state
280+
} andThen {
281+
val state = run("y")
282+
assertNotFoundError("y")
283+
state
284+
} andThen {
285+
run("a") // `a` should retain its original binding
286+
assertEquals("val res0: Int = 1234", storedOutput().trim)
287+
}
288+
289+
@Test def i4416_imports = initially {
290+
run("import scala.collection.mutable")
291+
} andThen {
292+
val state = run("import scala.util.Try; val x = ???")
293+
val _ = storedOutput() // discard output
294+
state
295+
} andThen {
296+
run(":imports") // scala.util.Try should not be imported
297+
assertEquals("import scala.collection.mutable", storedOutput().trim)
298+
}
299+
300+
@Test def i4416_types_defs_aliases = initially {
301+
val state =
302+
run("""|type Foo = String
303+
|trait Bar
304+
|def bar: Bar = ???
305+
|val x = ???
306+
|""".stripMargin)
307+
val all = lines()
308+
assertEquals(3, all.length)
309+
assertEquals("scala.NotImplementedError: an implementation is missing", all.head)
310+
assert("type alias in failed wrapper should not be rendered",
311+
!all.exists(_.startsWith("// defined alias type Foo = String")))
312+
assert("type definitions in failed wrapper should not be rendered",
313+
!all.exists(_.startsWith("// defined trait Bar")))
314+
assert("defs in failed wrapper should not be rendered",
315+
!all.exists(_.startsWith("def bar: Bar")))
316+
state
317+
} andThen {
318+
val state = run("def foo: Foo = ???")
319+
assertNotFoundError("type Foo")
320+
state
321+
} andThen {
322+
val state = run("type B = Bar")
323+
assertNotFoundError("type Bar")
324+
state
325+
} andThen {
326+
run("bar")
327+
assertNotFoundError("bar")
328+
}
329+
330+
@Test def i14473 = initially {
331+
run("""val (x,y) = if true then "hi" else (42,17)""")
332+
val all = lines()
333+
assertEquals(2, all.length)
334+
assertEquals("scala.MatchError: hi (of class java.lang.String)", all.head)
335+
}
336+
246337
@Test def i14491 =
247338
initially {
248339
run("import language.experimental.fewerBraces")

0 commit comments

Comments
 (0)