Skip to content

Commit 1736fb9

Browse files
committed
Restore state if errors are thrown away
1 parent f088428 commit 1736fb9

File tree

1 file changed

+68
-42
lines changed

1 file changed

+68
-42
lines changed

compiler/src/dotty/tools/dotc/transform/init/Semantic.scala

Lines changed: 68 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -296,28 +296,28 @@ object Semantic:
296296
case None => stable.get(value, expr)
297297
case res => res
298298

299-
/** Conditionally perform an operation
299+
/** Backup the state of the cache
300300
*
301-
* If the operation returns true, the changes are commited. Otherwise, the changes are reverted.
301+
* All the shared data structures must be immutable.
302302
*/
303-
def conditionally[T](fn: => (Boolean, T)): T =
304-
val last2 = this.last
305-
val current2 = this.current
306-
val stable2 = this.stable
307-
val heap2 = this.heap
308-
val heapStable2 = this.heapStable
309-
val changed2 = this.changed
310-
val (commit, value) = fn
311-
312-
if commit then
313-
this.last = last2
314-
this.current = current2
315-
this.stable = stable2
316-
this.heap = heap2
317-
this.heapStable = heapStable2
318-
this.changed = changed2
319-
320-
value
303+
def backup(): Cache =
304+
val cache = new Cache
305+
cache.last = this.last
306+
cache.current = this.current
307+
cache.stable = this.stable
308+
cache.heap = this.heap
309+
cache.heapStable = this.heapStable
310+
cache.changed = this.changed
311+
cache
312+
313+
/** Restore state from a backup */
314+
def restore(cache: Cache) =
315+
this.last = cache.last
316+
this.current = cache.current
317+
this.stable = cache.stable
318+
this.heap = cache.heap
319+
this.heapStable = cache.heapStable
320+
this.changed = cache.changed
321321

322322
/** Copy the value of `(value, expr)` from the last cache to the current cache
323323
*
@@ -459,21 +459,36 @@ object Semantic:
459459
def report(err: Error): Unit
460460
def reportAll(errs: Seq[Error]): Unit = for err <- errs do report(err)
461461

462+
/** A TryReporter cannot be simply thrown away
463+
*
464+
* Either `abort` should be called or the errors be reported.
465+
*/
466+
trait TryReporter extends Reporter:
467+
def abort()(using Cache): Unit
468+
def errors: List[Error]
469+
462470
object Reporter:
463471
class BufferedReporter extends Reporter:
464472
private val buf = new mutable.ArrayBuffer[Error]
465473
def errors = buf.toList
466474
def report(err: Error) = buf += err
467475

476+
class TryBufferedReporter(backup: Cache) extends BufferedReporter with TryReporter:
477+
def abort()(using Cache): Unit = cache.restore(backup)
478+
468479
class ErrorFound(val error: Error) extends Exception
469480
class StopEarlyReporter extends Reporter:
470481
def report(err: Error) = throw new ErrorFound(err)
471482

472-
/** Capture all errors and return as a list */
473-
def errorsIn(fn: Reporter ?=> Unit): List[Error] =
474-
val reporter = new BufferedReporter
483+
/** Capture all errors with a TryReporter
484+
*
485+
* The TryReporter cannot be thrown away: either `abort` must be called or
486+
* the errors must be reported.
487+
*/
488+
def errorsIn(fn: Reporter ?=> Unit)(using Cache): TryReporter =
489+
val reporter = new TryBufferedReporter(cache.backup())
475490
fn(using reporter)
476-
reporter.errors.toList
491+
reporter
477492

478493
/** Stop on first error */
479494
def stopEarly(fn: Reporter ?=> Unit): List[Error] =
@@ -485,6 +500,11 @@ object Semantic:
485500
catch case ex: ErrorFound =>
486501
ex.error :: Nil
487502

503+
def hasErrors(fn: Reporter ?=> Unit)(using Cache): Boolean =
504+
val backup = cache.backup()
505+
val errors = stopEarly(fn)
506+
cache.restore(backup)
507+
errors.nonEmpty
488508

489509
inline def reporter(using r: Reporter): Reporter = r
490510

@@ -517,9 +537,8 @@ object Semantic:
517537
def widenArg: Contextual[Value] =
518538
a match
519539
case _: Ref | _: Fun =>
520-
val errors = Reporter.stopEarly { a.promote("Argument cannot be promoted to hot") }
521-
if errors.isEmpty then Hot
522-
else Cold
540+
val hasError = Reporter.hasErrors { a.promote("Argument cannot be promoted to hot") }
541+
if hasError then Cold else Hot
523542

524543
case RefSet(refs) =>
525544
refs.map(_.widenArg).join
@@ -662,9 +681,11 @@ object Semantic:
662681
var allArgsHot = true
663682
val allParamTypes = methodType.paramInfoss.flatten.map(_.repeatedToSingle)
664683
val errors = allParamTypes.zip(args).flatMap { (info, arg) =>
665-
val errors = Reporter.errorsIn { arg.promote }
666-
allArgsHot = allArgsHot && errors.isEmpty
667-
info match
684+
val tryReporter = Reporter.errorsIn { arg.promote }
685+
allArgsHot = allArgsHot && tryReporter.errors.isEmpty
686+
if tryReporter.errors.isEmpty then tryReporter.errors
687+
else
688+
info match
668689
case typeParamRef: TypeParamRef =>
669690
val bounds = typeParamRef.underlying.bounds
670691
val isWithinBounds = bounds.lo <:< defn.NothingType && defn.AnyType <:< bounds.hi
@@ -673,8 +694,12 @@ object Semantic:
673694
// type parameter T with Any as its upper bound and Nothing as its lower bound.
674695
// the other arguments should either correspond to a parameter type that is T
675696
// or that does not contain T as a component.
676-
if isWithinBounds && !otherParamContains then Nil else errors
677-
case _ => errors
697+
if isWithinBounds && !otherParamContains then
698+
tryReporter.abort()
699+
Nil
700+
else
701+
tryReporter.errors
702+
case _ => tryReporter.errors
678703
}
679704
(errors, allArgsHot)
680705

@@ -721,15 +746,16 @@ object Semantic:
721746
if target.hasSource then
722747
val cls = target.owner.enclosingClass.asClass
723748
val ddef = target.defTree.asInstanceOf[DefDef]
724-
val argErrors = Reporter.errorsIn { promoteArgs() }
749+
val tryReporter = Reporter.errorsIn { promoteArgs() }
725750
// normal method call
726-
if argErrors.nonEmpty && isSyntheticApply(meth) then
751+
if tryReporter.errors.nonEmpty && isSyntheticApply(meth) then
752+
tryReporter.abort()
727753
val klass = meth.owner.companionClass.asClass
728754
val outerCls = klass.owner.lexicallyEnclosingClass.asClass
729755
val outer = resolveOuterSelect(outerCls, ref, 1)
730756
outer.instantiate(klass, klass.primaryConstructor, args)
731757
else
732-
reporter.reportAll(argErrors)
758+
reporter.reportAll(tryReporter.errors)
733759
extendTrace(ddef) {
734760
eval(ddef.rhs, ref, cls, cacheResult = true)
735761
}
@@ -841,15 +867,15 @@ object Semantic:
841867
if promoted.isCurrentObjectPromoted then Hot
842868
else value match {
843869
case Hot =>
844-
val buffer = new mutable.ArrayBuffer[Error]
870+
var allHot = true
845871
val args2 = args.map { arg =>
846-
val errors = Reporter.errorsIn { arg.promote }
847-
buffer ++= errors
848-
if errors.isEmpty then Hot
849-
else arg.value.widenArg
872+
val hasErrors = Reporter.hasErrors { arg.promote }
873+
allHot = allHot && !hasErrors
874+
if hasErrors then arg.value.widenArg
875+
else Hot
850876
}
851877

852-
if buffer.isEmpty then
878+
if allHot then
853879
Hot
854880
else
855881
val outer = Hot
@@ -998,7 +1024,7 @@ object Semantic:
9981024
given Trace = Trace.empty.add(body)
9991025
res.promote("The function return value is not fully initialized.")
10001026
}
1001-
if (errors.nonEmpty)
1027+
if errors.nonEmpty then
10021028
reporter.report(UnsafePromotion(msg, trace.toVector, errors.head))
10031029
else
10041030
promoted.add(fun)

0 commit comments

Comments
 (0)