Skip to content

Commit ee0dd7a

Browse files
authored
Fix scala#21619: Refactor NotNullInfo to record every reference which is retracted once. (scala#21624)
This PR improves the flow typing for returning and exceptions. The `NotNullInfo` is defined as following now: ```scala case class NotNullInfo(asserted: Set[TermRef] | Null, retracted: Set[TermRef]): ``` * `retracted` contains variable references that are ever assigned to null; * if `asserted` is not `null`, it contains `val` or `var` references that are known to be not null, after the tree finishes executing normally (non-exceptionally); * if `asserted` is `null`, the tree is know to terminate, by throwing, returning, or calling a function with `Nothing` type. Hence, it acts like a universal set. `alt` is defined as `<a1,r1>.alt(<a2,r2>) = <a1 intersect a2, r1 union r2>`. The difficult part is the `try ... catch ... finally ...`. We don't know at which point an exception is thrown in the body, and the catch cases may be not exhaustive, we have to collect any reference that is once retracted. Fix scala#21619
2 parents e6b4222 + 200c038 commit ee0dd7a

File tree

9 files changed

+257
-78
lines changed

9 files changed

+257
-78
lines changed

compiler/src/dotty/tools/dotc/core/Contexts.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -777,13 +777,13 @@ object Contexts {
777777

778778
extension (c: Context)
779779
def addNotNullInfo(info: NotNullInfo) =
780-
c.withNotNullInfos(c.notNullInfos.extendWith(info))
780+
if c.explicitNulls then c.withNotNullInfos(c.notNullInfos.extendWith(info)) else c
781781

782782
def addNotNullRefs(refs: Set[TermRef]) =
783-
c.addNotNullInfo(NotNullInfo(refs, Set()))
783+
if c.explicitNulls then c.addNotNullInfo(NotNullInfo(refs, Set())) else c
784784

785785
def withNotNullInfos(infos: List[NotNullInfo]): Context =
786-
if c.notNullInfos eq infos then c else c.fresh.setNotNullInfos(infos)
786+
if !c.explicitNulls || (c.notNullInfos eq infos) then c else c.fresh.setNotNullInfos(infos)
787787

788788
def relaxedOverrideContext: Context =
789789
c.withModeBits(c.mode &~ Mode.SafeNulls | Mode.RelaxedOverriding)

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1134,7 +1134,7 @@ trait Applications extends Compatibility {
11341134
case _ => ()
11351135
else ()
11361136

1137-
fun1.tpe match {
1137+
val result = fun1.tpe match {
11381138
case err: ErrorType => cpy.Apply(tree)(fun1, proto.typedArgs()).withType(err)
11391139
case TryDynamicCallType =>
11401140
val isInsertedApply = fun1 match {
@@ -1208,6 +1208,11 @@ trait Applications extends Compatibility {
12081208
else tryWithImplicitOnQualifier(fun1, proto).getOrElse(fail))
12091209
}
12101210
}
1211+
1212+
if result.tpe.isNothingType then
1213+
val nnInfo = result.notNullInfo
1214+
result.withNotNullInfo(nnInfo.terminatedInfo)
1215+
else result
12111216
}
12121217

12131218
/** Convert expression like

compiler/src/dotty/tools/dotc/typer/Nullables.scala

Lines changed: 76 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -52,34 +52,46 @@ object Nullables:
5252
val hiTree = if(hiTpe eq hi.typeOpt) hi else TypeTree(hiTpe)
5353
TypeBoundsTree(lo, hiTree, alias)
5454

55-
/** A set of val or var references that are known to be not null, plus a set of
56-
* variable references that are not known (anymore) to be not null
55+
/** A set of val or var references that are known to be not null
56+
* after the tree finishes executing normally (non-exceptionally),
57+
* plus a set of variable references that are ever assigned to null,
58+
* and may therefore be null if execution of the tree is interrupted
59+
* by an exception.
5760
*/
58-
case class NotNullInfo(asserted: Set[TermRef], retracted: Set[TermRef]):
59-
assert((asserted & retracted).isEmpty)
60-
61+
case class NotNullInfo(asserted: Set[TermRef] | Null, retracted: Set[TermRef]):
6162
def isEmpty = this eq NotNullInfo.empty
6263

6364
def retractedInfo = NotNullInfo(Set(), retracted)
6465

66+
def terminatedInfo = NotNullInfo(null, retracted)
67+
6568
/** The sequential combination with another not-null info */
6669
def seq(that: NotNullInfo): NotNullInfo =
6770
if this.isEmpty then that
6871
else if that.isEmpty then this
69-
else NotNullInfo(
70-
this.asserted.union(that.asserted).diff(that.retracted),
71-
this.retracted.union(that.retracted).diff(that.asserted))
72+
else
73+
val newAsserted =
74+
if this.asserted == null || that.asserted == null then null
75+
else this.asserted.diff(that.retracted).union(that.asserted)
76+
val newRetracted = this.retracted.union(that.retracted)
77+
NotNullInfo(newAsserted, newRetracted)
7278

7379
/** The alternative path combination with another not-null info. Used to merge
74-
* the nullability info of the two branches of an if.
80+
* the nullability info of the branches of an if or match.
7581
*/
7682
def alt(that: NotNullInfo): NotNullInfo =
77-
NotNullInfo(this.asserted.intersect(that.asserted), this.retracted.union(that.retracted))
83+
val newAsserted =
84+
if this.asserted == null then that.asserted
85+
else if that.asserted == null then this.asserted
86+
else this.asserted.intersect(that.asserted)
87+
val newRetracted = this.retracted.union(that.retracted)
88+
NotNullInfo(newAsserted, newRetracted)
89+
end NotNullInfo
7890

7991
object NotNullInfo:
8092
val empty = new NotNullInfo(Set(), Set())
81-
def apply(asserted: Set[TermRef], retracted: Set[TermRef]): NotNullInfo =
82-
if asserted.isEmpty && retracted.isEmpty then empty
93+
def apply(asserted: Set[TermRef] | Null, retracted: Set[TermRef]): NotNullInfo =
94+
if asserted != null && asserted.isEmpty && retracted.isEmpty then empty
8395
else new NotNullInfo(asserted, retracted)
8496
end NotNullInfo
8597

@@ -223,7 +235,7 @@ object Nullables:
223235
*/
224236
@tailrec def impliesNotNull(ref: TermRef): Boolean = infos match
225237
case info :: infos1 =>
226-
if info.asserted.contains(ref) then true
238+
if info.asserted == null || info.asserted.contains(ref) then true
227239
else if info.retracted.contains(ref) then false
228240
else infos1.impliesNotNull(ref)
229241
case _ =>
@@ -233,16 +245,15 @@ object Nullables:
233245
* or retractions in `info` supersede infos in existing entries of `infos`.
234246
*/
235247
def extendWith(info: NotNullInfo) =
236-
if info.isEmpty
237-
|| info.asserted.forall(infos.impliesNotNull(_))
238-
&& !info.retracted.exists(infos.impliesNotNull(_))
239-
then infos
248+
if info.isEmpty then infos
240249
else info :: infos
241250

242251
/** Retract all references to mutable variables */
243252
def retractMutables(using Context) =
244-
val mutables = infos.foldLeft(Set[TermRef]())((ms, info) =>
245-
ms.union(info.asserted.filter(_.symbol.is(Mutable))))
253+
val mutables = infos.foldLeft(Set[TermRef]()):
254+
(ms, info) => ms.union(
255+
if info.asserted == null then Set.empty
256+
else info.asserted.filter(_.symbol.is(Mutable)))
246257
infos.extendWith(NotNullInfo(Set(), mutables))
247258

248259
end extension
@@ -304,15 +315,35 @@ object Nullables:
304315
extension (tree: Tree)
305316

306317
/* The `tree` with added nullability attachment */
307-
def withNotNullInfo(info: NotNullInfo): tree.type =
308-
if !info.isEmpty then tree.putAttachment(NNInfo, info)
318+
def withNotNullInfo(info: NotNullInfo)(using Context): tree.type =
319+
if ctx.explicitNulls && !info.isEmpty then tree.putAttachment(NNInfo, info)
309320
tree
310321

322+
/* Collect the nullability info from parts of `tree` */
323+
def collectNotNullInfo(using Context): NotNullInfo = tree match
324+
case Typed(expr, _) =>
325+
expr.notNullInfo
326+
case Apply(fn, args) =>
327+
val argsInfo = args.map(_.notNullInfo)
328+
val fnInfo = fn.notNullInfo
329+
argsInfo.foldLeft(fnInfo)(_ seq _)
330+
case TypeApply(fn, _) =>
331+
fn.notNullInfo
332+
case _ =>
333+
// Other cases are handled specially in typer.
334+
NotNullInfo.empty
335+
311336
/* The nullability info of `tree` */
312337
def notNullInfo(using Context): NotNullInfo =
313-
stripInlined(tree).getAttachment(NNInfo) match
314-
case Some(info) if !ctx.erasedTypes => info
315-
case _ => NotNullInfo.empty
338+
if !ctx.explicitNulls then NotNullInfo.empty
339+
else
340+
val tree1 = stripInlined(tree)
341+
tree1.getAttachment(NNInfo) match
342+
case Some(info) if !ctx.erasedTypes => info
343+
case _ =>
344+
val nnInfo = tree1.collectNotNullInfo
345+
tree1.withNotNullInfo(nnInfo)
346+
nnInfo
316347

317348
/* The nullability info of `tree`, assuming it is a condition that evaluates to `c` */
318349
def notNullInfoIf(c: Boolean)(using Context): NotNullInfo =
@@ -393,21 +424,23 @@ object Nullables:
393424
end extension
394425

395426
extension (tree: Assign)
396-
def computeAssignNullable()(using Context): tree.type = tree.lhs match
397-
case TrackedRef(ref) =>
398-
val rhstp = tree.rhs.typeOpt
399-
if ctx.explicitNulls && ref.isNullableUnion then
400-
if rhstp.isNullType || rhstp.isNullableUnion then
401-
// If the type of rhs is nullable (`T|Null` or `Null`), then the nullability of the
402-
// lhs variable is no longer trackable. We don't need to check whether the type `T`
403-
// is correct here, as typer will check it.
404-
tree.withNotNullInfo(NotNullInfo(Set(), Set(ref)))
405-
else
406-
// If the initial type is nullable and the assigned value is non-null,
407-
// we add it to the NotNull.
408-
tree.withNotNullInfo(NotNullInfo(Set(ref), Set()))
409-
else tree
410-
case _ => tree
427+
def computeAssignNullable()(using Context): tree.type =
428+
var nnInfo = tree.rhs.notNullInfo
429+
tree.lhs match
430+
case TrackedRef(ref) if ctx.explicitNulls && ref.isNullableUnion =>
431+
nnInfo = nnInfo.seq:
432+
val rhstp = tree.rhs.typeOpt
433+
if rhstp.isNullType || rhstp.isNullableUnion then
434+
// If the type of rhs is nullable (`T|Null` or `Null`), then the nullability of the
435+
// lhs variable is no longer trackable. We don't need to check whether the type `T`
436+
// is correct here, as typer will check it.
437+
NotNullInfo(Set(), Set(ref))
438+
else
439+
// If the initial type is nullable and the assigned value is non-null,
440+
// we add it to the NotNull.
441+
NotNullInfo(Set(ref), Set())
442+
case _ =>
443+
tree.withNotNullInfo(nnInfo)
411444
end extension
412445

413446
private val analyzedOps = Set(nme.EQ, nme.NE, nme.eq, nme.ne, nme.ZAND, nme.ZOR, nme.UNARY_!)
@@ -515,7 +548,10 @@ object Nullables:
515548
&& assignmentSpans.getOrElse(sym.span.start, Nil).exists(whileSpan.contains(_))
516549
&& ctx.notNullInfos.impliesNotNull(ref)
517550

518-
val retractedVars = ctx.notNullInfos.flatMap(_.asserted.filter(isRetracted)).toSet
551+
val retractedVars = ctx.notNullInfos.flatMap(info =>
552+
if info.asserted == null then Set.empty
553+
else info.asserted.filter(isRetracted)
554+
).toSet
519555
ctx.addNotNullInfo(NotNullInfo(Set(), retractedVars))
520556
end whileContext
521557

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,7 +1201,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
12011201
untpd.unsplice(tree.expr).putAttachment(AscribedToUnit, ())
12021202
typed(tree.expr, underlyingTreeTpe.tpe.widenSkolem)
12031203
assignType(cpy.Typed(tree)(expr1, tpt), underlyingTreeTpe)
1204-
.withNotNullInfo(expr1.notNullInfo)
12051204
}
12061205

12071206
if (untpd.isWildcardStarArg(tree)) {
@@ -1551,11 +1550,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
15511550

15521551
def thenPathInfo = cond1.notNullInfoIf(true).seq(result.thenp.notNullInfo)
15531552
def elsePathInfo = cond1.notNullInfoIf(false).seq(result.elsep.notNullInfo)
1554-
result.withNotNullInfo(
1555-
if result.thenp.tpe.isRef(defn.NothingClass) then elsePathInfo
1556-
else if result.elsep.tpe.isRef(defn.NothingClass) then thenPathInfo
1557-
else thenPathInfo.alt(elsePathInfo)
1558-
)
1553+
result.withNotNullInfo(thenPathInfo.alt(elsePathInfo))
15591554
end typedIf
15601555

15611556
/** Decompose function prototype into a list of parameter prototypes and a result
@@ -2139,20 +2134,25 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
21392134
case1
21402135
}
21412136
.asInstanceOf[List[CaseDef]]
2142-
var nni = sel.notNullInfo
2143-
if cases1.nonEmpty then nni = nni.seq(cases1.map(_.notNullInfo).reduce(_.alt(_)))
2144-
assignType(cpy.Match(tree)(sel, cases1), sel, cases1).cast(pt).withNotNullInfo(nni)
2137+
assignType(cpy.Match(tree)(sel, cases1), sel, cases1).cast(pt)
2138+
.withNotNullInfo(notNullInfoFromCases(sel.notNullInfo, cases1))
21452139
}
21462140

21472141
// Overridden in InlineTyper for inline matches
21482142
def typedMatchFinish(tree: untpd.Match, sel: Tree, wideSelType: Type, cases: List[untpd.CaseDef], pt: Type)(using Context): Tree = {
21492143
val cases1 = harmonic(harmonize, pt)(typedCases(cases, sel, wideSelType, pt.dropIfProto))
21502144
.asInstanceOf[List[CaseDef]]
2151-
var nni = sel.notNullInfo
2152-
if cases1.nonEmpty then nni = nni.seq(cases1.map(_.notNullInfo).reduce(_.alt(_)))
2153-
assignType(cpy.Match(tree)(sel, cases1), sel, cases1).withNotNullInfo(nni)
2145+
assignType(cpy.Match(tree)(sel, cases1), sel, cases1)
2146+
.withNotNullInfo(notNullInfoFromCases(sel.notNullInfo, cases1))
21542147
}
21552148

2149+
private def notNullInfoFromCases(initInfo: NotNullInfo, cases: List[CaseDef])(using Context): NotNullInfo =
2150+
if cases.isEmpty then
2151+
// Empty cases is not allowed for match tree in the source code,
2152+
// but it can be generated by inlining: `tests/pos/i19198.scala`.
2153+
initInfo
2154+
else cases.map(_.notNullInfo).reduce(_.alt(_))
2155+
21562156
def typedCases(cases: List[untpd.CaseDef], sel: Tree, wideSelType0: Type, pt: Type)(using Context): List[CaseDef] =
21572157
var caseCtx = ctx
21582158
var wideSelType = wideSelType0
@@ -2241,7 +2241,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
22412241
def typedLabeled(tree: untpd.Labeled)(using Context): Labeled = {
22422242
val bind1 = typedBind(tree.bind, WildcardType).asInstanceOf[Bind]
22432243
val expr1 = typed(tree.expr, bind1.symbol.info)
2244-
assignType(cpy.Labeled(tree)(bind1, expr1))
2244+
assignType(cpy.Labeled(tree)(bind1, expr1)).withNotNullInfo(expr1.notNullInfo.retractedInfo)
22452245
}
22462246

22472247
/** Type a case of a type match */
@@ -2291,7 +2291,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
22912291
// Hence no adaptation is possible, and we assume WildcardType as prototype.
22922292
(from, proto)
22932293
val expr1 = typedExpr(tree.expr orElse untpd.syntheticUnitLiteral.withSpan(tree.span), proto)
2294-
assignType(cpy.Return(tree)(expr1, from))
2294+
assignType(cpy.Return(tree)(expr1, from)).withNotNullInfo(expr1.notNullInfo.terminatedInfo)
22952295
end typedReturn
22962296

22972297
def typedWhileDo(tree: untpd.WhileDo)(using Context): Tree =
@@ -2332,7 +2332,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
23322332
val capabilityProof = caughtExceptions.reduce(OrType(_, _, true))
23332333
untpd.Block(makeCanThrow(capabilityProof), expr)
23342334

2335-
def typedTry(tree: untpd.Try, pt: Type)(using Context): Try = {
2335+
def typedTry(tree: untpd.Try, pt: Type)(using Context): Try =
2336+
var nnInfo = NotNullInfo.empty
23362337
val expr2 :: cases2x = harmonic(harmonize, pt) {
23372338
// We want to type check tree.expr first to comput NotNullInfo, but `addCanThrowCapabilities`
23382339
// uses the types of patterns in `tree.cases` to determine the capabilities.
@@ -2344,18 +2345,26 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
23442345
val casesEmptyBody1 = tree.cases.mapconserve(cpy.CaseDef(_)(body = EmptyTree))
23452346
val casesEmptyBody2 = typedCases(casesEmptyBody1, EmptyTree, defn.ThrowableType, WildcardType)
23462347
val expr1 = typed(addCanThrowCapabilities(tree.expr, casesEmptyBody2), pt.dropIfProto)
2347-
val casesCtx = ctx.addNotNullInfo(expr1.notNullInfo.retractedInfo)
2348+
2349+
// Since we don't know at which point the the exception is thrown in the body,
2350+
// we have to collect any reference that is once retracted.
2351+
nnInfo = expr1.notNullInfo.retractedInfo
2352+
2353+
val casesCtx = ctx.addNotNullInfo(nnInfo)
23482354
val cases1 = typedCases(tree.cases, EmptyTree, defn.ThrowableType, pt.dropIfProto)(using casesCtx)
23492355
expr1 :: cases1
23502356
}: @unchecked
23512357
val cases2 = cases2x.asInstanceOf[List[CaseDef]]
23522358

2353-
var nni = expr2.notNullInfo.retractedInfo
2354-
if cases2.nonEmpty then nni = nni.seq(cases2.map(_.notNullInfo.retractedInfo).reduce(_.alt(_)))
2355-
val finalizer1 = typed(tree.finalizer, defn.UnitType)(using ctx.addNotNullInfo(nni))
2356-
nni = nni.seq(finalizer1.notNullInfo)
2357-
assignType(cpy.Try(tree)(expr2, cases2, finalizer1), expr2, cases2).withNotNullInfo(nni)
2358-
}
2359+
// It is possible to have non-exhaustive cases, and some exceptions are thrown and not caught.
2360+
// Therefore, the code in the finalizer and after the try block can only rely on the retracted
2361+
// info from the cases' body.
2362+
if cases2.nonEmpty then
2363+
nnInfo = nnInfo.seq(cases2.map(_.notNullInfo.retractedInfo).reduce(_.alt(_)))
2364+
2365+
val finalizer1 = typed(tree.finalizer, defn.UnitType)(using ctx.addNotNullInfo(nnInfo))
2366+
nnInfo = nnInfo.seq(finalizer1.notNullInfo)
2367+
assignType(cpy.Try(tree)(expr2, cases2, finalizer1), expr2, cases2).withNotNullInfo(nnInfo)
23592368

23602369
def typedTry(tree: untpd.ParsedTry, pt: Type)(using Context): Try =
23612370
val cases: List[untpd.CaseDef] = tree.handler match
@@ -2369,15 +2378,15 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
23692378
def typedThrow(tree: untpd.Throw)(using Context): Tree =
23702379
val expr1 = typed(tree.expr, defn.ThrowableType)
23712380
val cap = checkCanThrow(expr1.tpe.widen, tree.span)
2372-
val res = Throw(expr1).withSpan(tree.span)
2381+
var res = Throw(expr1).withSpan(tree.span)
23732382
if Feature.ccEnabled && !cap.isEmpty && !ctx.isAfterTyper then
23742383
// Record access to the CanThrow capabulity recovered in `cap` by wrapping
23752384
// the type of the `throw` (i.e. Nothing) in a `@requiresCapability` annotation.
2376-
Typed(res,
2385+
res = Typed(res,
23772386
TypeTree(
23782387
AnnotatedType(res.tpe,
23792388
Annotation(defn.RequiresCapabilityAnnot, cap, tree.span))))
2380-
else res
2389+
res.withNotNullInfo(expr1.notNullInfo.terminatedInfo)
23812390

23822391
def typedSeqLiteral(tree: untpd.SeqLiteral, pt: Type)(using Context): SeqLiteral = {
23832392
val elemProto = pt.stripNull().elemType match {
@@ -2842,6 +2851,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
28422851
val vdef1 = assignType(cpy.ValDef(vdef)(name, tpt1, rhs1), sym)
28432852
postProcessInfo(vdef1, sym)
28442853
vdef1.setDefTree
2854+
val nnInfo = rhs1.notNullInfo
2855+
vdef1.withNotNullInfo(if sym.is(Lazy) then nnInfo.retractedInfo else nnInfo)
28452856
}
28462857

28472858
private def retractDefDef(sym: Symbol)(using Context): Tree =

tests/explicit-nulls/neg/i21380b.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,22 @@ def test3(i: Int) =
1818
i match
1919
case 1 if x != null => ()
2020
case _ => x = " "
21+
x.trim() // ok
22+
23+
def test4(i: Int) =
24+
var x: String | Null = null
25+
var y: String | Null = null
26+
i match
27+
case 1 => x = "1"
28+
case _ => y = " "
29+
x.trim() // error
30+
31+
def test5(i: Int): String =
32+
var x: String | Null = null
33+
var y: String | Null = null
34+
i match
35+
case 1 => x = "1"
36+
case _ =>
37+
y = " "
38+
return y
2139
x.trim() // ok

tests/explicit-nulls/neg/i21380c.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ def test4: Int =
3232
case npe: NullPointerException => x = ""
3333
case _ => x = ""
3434
x.length // error
35-
// Although the catch block here is exhaustive,
36-
// it is possible that the exception is thrown and not caught.
37-
// Therefore, the code after the try block can only rely on the retracted info.
35+
// Although the catch block here is exhaustive, it is possible to have non-exhaustive cases,
36+
// and some exceptions are thrown and not caught. Therefore, the code in the finalizer and
37+
// after the try block can only rely on the retracted info from the cases' body.
3838

3939
def test5: Int =
4040
var x: String | Null = null

0 commit comments

Comments
 (0)