Skip to content

Commit 9445d16

Browse files
committed
Relax overriding by stripping nulls deeply
1 parent 2ef89b2 commit 9445d16

File tree

7 files changed

+143
-66
lines changed

7 files changed

+143
-66
lines changed

compiler/src/dotty/tools/dotc/core/NullOpsDecorator.scala

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,41 @@ import Types._
99
/** Defines operations on nullable types and tree. */
1010
object NullOpsDecorator:
1111

12+
private class StripNullsMap(isDeep: Boolean)(using Context) extends TypeMap:
13+
def strip(tp: Type): Type = tp match
14+
case tp @ OrType(lhs, rhs) =>
15+
val llhs = this(lhs)
16+
val rrhs = this(rhs)
17+
if rrhs.isNullType then llhs
18+
else if llhs.isNullType then rrhs
19+
else derivedOrType(tp, llhs, rrhs)
20+
case tp @ AndType(tp1, tp2) =>
21+
// We cannot `tp.derivedAndType(strip(tp1), strip(tp2))` directly,
22+
// since `stripNull((A | Null) & B)` would produce the wrong
23+
// result `(A & B) | Null`.
24+
val tp1s = this(tp1)
25+
val tp2s = this(tp2)
26+
if isDeep || (tp1s ne tp1) && (tp2s ne tp2) then
27+
derivedAndType(tp, tp1s, tp2s)
28+
else tp
29+
case tp: TypeBounds =>
30+
mapOver(tp)
31+
case _ => tp
32+
33+
def stripOver(tp: Type): Type = tp match
34+
case appTp @ AppliedType(tycon, targs) =>
35+
derivedAppliedType(appTp, tycon, targs.map(this))
36+
case ptp: PolyType =>
37+
derivedLambdaType(ptp)(ptp.paramInfos, this(ptp.resType))
38+
case mtp: MethodType =>
39+
mapOver(mtp)
40+
case _ => strip(tp)
41+
42+
override def apply(tp: Type): Type =
43+
if isDeep then stripOver(tp) else strip(tp)
44+
45+
end StripNullsMap
46+
1247
extension (self: Type)
1348
/** Syntactically strips the nullability from this type.
1449
* If the type is `T1 | ... | Tn`, and `Ti` references to `Null`,
@@ -17,38 +52,30 @@ object NullOpsDecorator:
1752
* The type will not be changed if explicit-nulls is not enabled.
1853
*/
1954
def stripNull(using Context): Type = {
20-
def strip(tp: Type): Type =
21-
val tpWiden = tp.widenDealias
22-
val tpStripped = tpWiden match {
23-
case tp @ OrType(lhs, rhs) =>
24-
val llhs = strip(lhs)
25-
val rrhs = strip(rhs)
26-
if rrhs.isNullType then llhs
27-
else if llhs.isNullType then rrhs
28-
else tp.derivedOrType(llhs, rrhs)
29-
case tp @ AndType(tp1, tp2) =>
30-
// We cannot `tp.derivedAndType(strip(tp1), strip(tp2))` directly,
31-
// since `stripNull((A | Null) & B)` would produce the wrong
32-
// result `(A & B) | Null`.
33-
val tp1s = strip(tp1)
34-
val tp2s = strip(tp2)
35-
if (tp1s ne tp1) && (tp2s ne tp2) then
36-
tp.derivedAndType(tp1s, tp2s)
37-
else tp
38-
case tp @ TypeBounds(lo, hi) =>
39-
tp.derivedTypeBounds(strip(lo), strip(hi))
40-
case tp => tp
41-
}
42-
if tpStripped ne tpWiden then tpStripped else tp
43-
44-
if ctx.explicitNulls then strip(self) else self
55+
if ctx.explicitNulls then
56+
val selfw = self.widenDealias
57+
val selfws = new StripNullsMap(false)(selfw)
58+
if selfws ne selfw then selfws else self
59+
else self
4560
}
4661

4762
/** Is self (after widening and dealiasing) a type of the form `T | Null`? */
4863
def isNullableUnion(using Context): Boolean = {
4964
val stripped = self.stripNull
5065
stripped ne self
5166
}
67+
68+
/** Strips nulls from this type deeply.
69+
* Compaired to `stripNull`, `stripNullsDeep` will apply `stripNull` to
70+
* each member of function types as well.
71+
*/
72+
def stripNullsDeep(using Context): Type =
73+
if ctx.explicitNulls then
74+
val selfw = self.widenDealias
75+
val selfws = new StripNullsMap(true)(selfw)
76+
if selfws ne selfw then selfws else self
77+
else self
78+
5279
end extension
5380

5481
import ast.tpd._

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1112,8 +1112,10 @@ object Types {
11121112
*/
11131113
def matches(that: Type)(using Context): Boolean = {
11141114
record("matches")
1115+
val thisTp1 = this.stripNullsDeep
1116+
val thatTp1 = that.stripNullsDeep
11151117
withoutMode(Mode.SafeNulls)(
1116-
TypeComparer.matchesType(this, that, relaxed = !ctx.phase.erasedTypes))
1118+
TypeComparer.matchesType(thisTp1, thatTp1, relaxed = !ctx.phase.erasedTypes))
11171119
}
11181120

11191121
/** This is the same as `matches` except that it also matches => T with T and

compiler/src/dotty/tools/dotc/transform/OverridingPairs.scala

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package transform
55
import core._
66
import Flags._, Symbols._, Contexts._, Scopes._, Decorators._, Types.Type
77
import NameKinds.DefaultGetterName
8+
import NullOpsDecorator._
89
import collection.mutable
910
import collection.immutable.BitSet
1011
import scala.annotation.tailrec
@@ -215,15 +216,20 @@ object OverridingPairs:
215216
}
216217
)
217218
else
218-
// releaxed override check for explicit nulls if one of the symbols is Java defined,
219-
// force `Null` being a subtype of reference types during override checking
220-
val relaxedCtxForNulls =
219+
def matchNullaryLoosely = member.matchNullaryLoosely || other.matchNullaryLoosely || fallBack
220+
// default getters are not checked for compatibility
221+
member.name.is(DefaultGetterName) || {
221222
if ctx.explicitNulls && (member.is(JavaDefined) || other.is(JavaDefined)) then
222-
ctx.retractMode(Mode.SafeNulls)
223-
else ctx
224-
member.name.is(DefaultGetterName) // default getters are not checked for compatibility
225-
|| memberTp.overrides(otherTp,
226-
member.matchNullaryLoosely || other.matchNullaryLoosely || fallBack
227-
)(using relaxedCtxForNulls)
223+
// releaxed override check for explicit nulls if one of the symbols is Java defined,
224+
// force `Null` being a subtype of reference types during override checking.
225+
// `stripNullsDeep` is used here because we may encounter type parameters
226+
// (`T | Null` is not a subtype of `T` even if we retract Mode.SafeNulls).
227+
val memberTp1 = memberTp.stripNullsDeep
228+
val otherTp1 = otherTp.stripNullsDeep
229+
withoutMode(Mode.SafeNulls)(
230+
memberTp1.overrides(otherTp1, matchNullaryLoosely))
231+
else
232+
memberTp.overrides(otherTp, matchNullaryLoosely)
233+
}
228234

229235
end OverridingPairs

compiler/src/dotty/tools/dotc/transform/ResolveSuper.scala

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import Names._
1313
import StdNames._
1414
import NameOps._
1515
import NameKinds._
16+
import NullOpsDecorator._
1617
import ResolveSuper._
1718
import reporting.IllegalSuperAccessor
1819

@@ -110,11 +111,13 @@ object ResolveSuper {
110111
// Since the super class can be Java defined,
111112
// we use releaxed overriding check for explicit nulls if one of the symbols is Java defined.
112113
// This forces `Null` being a subtype of reference types during override checking.
113-
val relaxedCtxForNulls =
114-
if ctx.explicitNulls && (sym.is(JavaDefined) || acc.is(JavaDefined)) then
115-
ctx.retractMode(Mode.SafeNulls)
116-
else ctx
117-
if (!(otherTp.overrides(accTp, matchLoosely = true)(using relaxedCtxForNulls)))
114+
val overridesSuper = if ctx.explicitNulls && (sym.is(JavaDefined) || acc.is(JavaDefined)) then
115+
val otherTp1 = otherTp.stripNullsDeep
116+
val accTp1 = accTp.stripNullsDeep
117+
withoutMode(Mode.SafeNulls)(otherTp1.overrides(accTp1, matchLoosely = true))
118+
else
119+
otherTp.overrides(accTp, matchLoosely = true)
120+
if !overridesSuper then
118121
report.error(IllegalSuperAccessor(base, memberName, targetName, acc, accTp, other.symbol, otherTp), base.srcPos)
119122

120123
bcs = bcs.tail
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Unboxed option type using unions + null + opaque.
2+
// Relies on the fact that Null is not a subtype of AnyRef.
3+
// Test suggested by Sébastien Doeraene.
4+
5+
object Nullables {
6+
opaque type Nullable[+A <: AnyRef] = A | Null // disjoint by construction!
7+
8+
object Nullable:
9+
def apply[A <: AnyRef](x: A | Null): Nullable[A] = x
10+
11+
def some[A <: AnyRef](x: A): Nullable[A] = x
12+
def none: Nullable[Nothing] = null
13+
14+
extension [A <: AnyRef](x: Nullable[A])
15+
def isEmpty: Boolean = x == null
16+
def get: A | Null = x
17+
18+
extension [A <: AnyRef, B <: AnyRef](x: Nullable[A])
19+
def flatMap(f: A => Nullable[B]): Nullable[B] =
20+
if (x == null) null
21+
else f(x)
22+
23+
def map(f: A => B): Nullable[B] = x.flatMap(f)
24+
25+
def test1 =
26+
val s1: Nullable[String] = Nullable("hello")
27+
val s2: Nullable[String] = "world"
28+
val s3: Nullable[String] = Nullable.none
29+
val s4: Nullable[String] = null
30+
31+
s1.isEmpty
32+
s1.flatMap((x) => true)
33+
34+
assert(s2 != null)
35+
}
36+
37+
def test2 =
38+
import Nullables._
39+
40+
val s1: Nullable[String] = Nullable("hello")
41+
val s2: Nullable[String] = Nullable.none
42+
val s3: Nullable[String] = null // error: don't leak nullable union
43+
44+
s1.isEmpty
45+
s1.flatMap((x) => Nullable(true))
46+
47+
assert(s2 == null) // error

tests/explicit-nulls/pos/opaque-nullable.scala

Lines changed: 0 additions & 26 deletions
This file was deleted.
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// Testing relaxed overriding check for explicit nulls.
2+
// The relaxed check is only enabled if one of the members is Java defined.
3+
4+
import java.util.Comparator
5+
6+
class C1[T <: AnyRef] extends Ordering[T]:
7+
override def compare(o1: T, o2: T): Int = 0
8+
9+
// The following overriding is not allowed, because `compare`
10+
// has already been declared in Scala class `Ordering`.
11+
// class C2[T <: AnyRef] extends Ordering[T]:
12+
// override def compare(o1: T | Null, o2: T | Null): Int = 0
13+
14+
class D1[T <: AnyRef] extends Comparator[T]:
15+
override def compare(o1: T, o2: T): Int = 0
16+
17+
class D2[T <: AnyRef] extends Comparator[T]:
18+
override def compare(o1: T | Null, o2: T | Null): Int = 0

0 commit comments

Comments
 (0)