Skip to content

Commit 00430c0

Browse files
committed
Treat asserted set of terminated NotNullInfo as universal set; fix test
1 parent 158af7d commit 00430c0

File tree

5 files changed

+40
-19
lines changed

5 files changed

+40
-19
lines changed

compiler/src/dotty/tools/dotc/core/Contexts.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -777,13 +777,13 @@ object Contexts {
777777

778778
extension (c: Context)
779779
def addNotNullInfo(info: NotNullInfo) =
780-
c.withNotNullInfos(c.notNullInfos.extendWith(info))
780+
if c.explicitNulls then c.withNotNullInfos(c.notNullInfos.extendWith(info)) else c
781781

782782
def addNotNullRefs(refs: Set[TermRef]) =
783-
c.addNotNullInfo(NotNullInfo(refs, Set()))
783+
if c.explicitNulls then c.addNotNullInfo(NotNullInfo(refs, Set())) else c
784784

785785
def withNotNullInfos(infos: List[NotNullInfo]): Context =
786-
if c.notNullInfos eq infos then c else c.fresh.setNotNullInfos(infos)
786+
if !c.explicitNulls || (c.notNullInfos eq infos) then c else c.fresh.setNotNullInfos(infos)
787787

788788
def relaxedOverrideContext: Context =
789789
c.withModeBits(c.mode &~ Mode.SafeNulls | Mode.RelaxedOverriding)

compiler/src/dotty/tools/dotc/typer/Nullables.scala

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ object Nullables:
235235
*/
236236
@tailrec def impliesNotNull(ref: TermRef): Boolean = infos match
237237
case info :: infos1 =>
238-
if info.asserted != null && info.asserted.contains(ref) then true
238+
if info.asserted == null || info.asserted.contains(ref) then true
239239
else if info.retracted.contains(ref) then false
240240
else infos1.impliesNotNull(ref)
241241
case _ =>
@@ -315,8 +315,8 @@ object Nullables:
315315
extension (tree: Tree)
316316

317317
/* The `tree` with added nullability attachment */
318-
def withNotNullInfo(info: NotNullInfo): tree.type =
319-
if !info.isEmpty then tree.putAttachment(NNInfo, info)
318+
def withNotNullInfo(info: NotNullInfo)(using Context): tree.type =
319+
if ctx.explicitNulls && !info.isEmpty then tree.putAttachment(NNInfo, info)
320320
tree
321321

322322
/* Collect the nullability info from parts of `tree` */
@@ -335,13 +335,15 @@ object Nullables:
335335

336336
/* The nullability info of `tree` */
337337
def notNullInfo(using Context): NotNullInfo =
338-
val tree1 = stripInlined(tree)
339-
tree1.getAttachment(NNInfo) match
340-
case Some(info) if !ctx.erasedTypes => info
341-
case _ =>
342-
val nnInfo = tree1.collectNotNullInfo
343-
tree1.withNotNullInfo(nnInfo)
344-
nnInfo
338+
if !ctx.explicitNulls then NotNullInfo.empty
339+
else
340+
val tree1 = stripInlined(tree)
341+
tree1.getAttachment(NNInfo) match
342+
case Some(info) if !ctx.erasedTypes => info
343+
case _ =>
344+
val nnInfo = tree1.collectNotNullInfo
345+
tree1.withNotNullInfo(nnInfo)
346+
nnInfo
345347

346348
/* The nullability info of `tree`, assuming it is a condition that evaluates to `c` */
347349
def notNullInfoIf(c: Boolean)(using Context): NotNullInfo =

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2849,6 +2849,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
28492849
val vdef1 = assignType(cpy.ValDef(vdef)(name, tpt1, rhs1), sym)
28502850
postProcessInfo(vdef1, sym)
28512851
vdef1.setDefTree
2852+
val nnInfo = rhs1.notNullInfo
2853+
vdef1.withNotNullInfo(if sym.is(Lazy) then nnInfo.retractedInfo else nnInfo)
28522854
}
28532855

28542856
private def retractDefDef(sym: Symbol)(using Context): Tree =
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
class C(val x: Int, val next: C | Null)
2+
3+
def test1(x: String | Null, c: C | Null): Int =
4+
return 0
5+
// We know that the following code is unreachable,
6+
// so we can treat `x`, `c`, and any variable/path non-nullable.
7+
x.length + c.next.x
8+
9+
def test2(x: String | Null, c: C | Null): Int =
10+
throw new Exception()
11+
x.length + c.next.x
12+
13+
def fail(): Nothing = ???
14+
15+
def test3(x: String | Null, c: C | Null): Int =
16+
fail()
17+
x.length + c.next.x

tests/explicit-nulls/unsafe-common/unsafe-overload.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ class S {
1616
val o: O = ???
1717

1818
locally {
19-
def h1(hh: String => String) = ???
20-
def h2(hh: Array[String] => Array[String]) = ???
19+
def h1(hh: String => String): Unit = ???
20+
def h2(hh: Array[String] => Array[String]): Unit = ???
2121
def f1(x: String | Null): String | Null = ???
2222
def f2(x: Array[String | Null]): Array[String | Null] = ???
2323

@@ -29,10 +29,10 @@ class S {
2929
}
3030

3131
locally {
32-
def h1(hh: String | Null => String | Null) = ???
33-
def h2(hh: Array[String | Null] => Array[String | Null]) = ???
32+
def h1(hh: String | Null => String | Null): Unit = ???
33+
def h2(hh: Array[String | Null] => Array[String | Null]): Unit = ???
3434
def g1(x: String): String = ???
35-
def g2(x: Array[String]): Array[String] = ???
35+
def g2(x: Array[String]): Array[String] = ???
3636

3737
h1(g1) // error
3838
h1(o.g) // error
@@ -51,7 +51,7 @@ class S {
5151

5252
locally {
5353
def g1(x: String): String = ???
54-
def g2(x: Array[String]): Array[String] = ???
54+
def g2(x: Array[String]): Array[String] = ???
5555

5656
o.i(g1) // error
5757
o.i(g2) // error

0 commit comments

Comments
 (0)