Skip to content

Commit 1495c5c

Browse files
committed
Add a synthetic RefEq trait for reference equality
Since `Null` is no longer a subtype of `AnyRef`, and `AnyRef` defines the `eq` and `neq` method, we were no longer able to do comparisons like `null.eq(s)`. Fix this by moving `eq` and `neq` to a new synthetic trait called `RefEq`, which is a supertype of both `Null` and `AnyRef`.
1 parent e09f711 commit 1495c5c

File tree

8 files changed

+167
-28
lines changed

8 files changed

+167
-28
lines changed

compiler/src/dotty/tools/backend/jvm/scalaPrimitives.scala

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -156,13 +156,23 @@ class DottyPrimitives(ctx: Context) {
156156
addPrimitive(defn.Any_##, HASH)
157157

158158
// java.lang.Object
159-
addPrimitive(defn.Object_eq, ID)
160-
addPrimitive(defn.Object_ne, NI)
161-
/* addPrimitive(defn.Any_==, EQ)
162-
addPrimitive(defn.Any_!=, NE)*/
163-
addPrimitive(defn.Object_synchronized, SYNCHRONIZED)
164-
/*addPrimitive(defn.Any_isInstanceOf, IS)
165-
addPrimitive(defn.Any_asInstanceOf, AS)*/
159+
if (ctx.settings.YexplicitNulls.value) {
160+
// scala.RefEq
161+
addPrimitive(defn.RefEq_eq, ID)
162+
addPrimitive(defn.RefEq_ne, NI)
163+
164+
// java.lang.Object
165+
addPrimitive(defn.Object_synchronized, SYNCHRONIZED)
166+
} else {
167+
// java.lang.Object
168+
addPrimitive(defn.Object_eq, ID)
169+
addPrimitive(defn.Object_ne, NI)
170+
/* addPrimitive(defn.Any_==, EQ)
171+
addPrimitive(defn.Any_!=, NE) */
172+
addPrimitive(defn.Object_synchronized, SYNCHRONIZED)
173+
/*addPrimitive(defn.Any_isInstanceOf, IS)
174+
addPrimitive(defn.Any_asInstanceOf, AS)*/
175+
}
166176

167177
// java.lang.String
168178
addPrimitive(defn.String_+, CONCAT)

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 79 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,12 @@ class Definitions {
286286
lazy val ObjectClass: ClassSymbol = {
287287
val cls = ctx.requiredClass("java.lang.Object")
288288
assert(!cls.isCompleted, "race for completing java.lang.Object")
289-
cls.info = ClassInfo(cls.owner.thisType, cls, AnyClass.typeRef :: Nil, newScope)
289+
val parents = if (ctx.settings.YexplicitNulls.value) {
290+
AnyType :: RefEqType :: Nil
291+
} else {
292+
AnyType :: Nil
293+
}
294+
cls.info = ClassInfo(cls.owner.thisType, cls, parents, newScope)
290295
cls.setFlag(NoInits)
291296

292297
// The companion object doesn't really exist, `NoType` is the general
@@ -303,8 +308,17 @@ class Definitions {
303308
lazy val AnyRefAlias: TypeSymbol = enterAliasType(tpnme.AnyRef, ObjectType)
304309
def AnyRefType: TypeRef = AnyRefAlias.typeRef
305310

306-
lazy val Object_eq: TermSymbol = enterMethod(ObjectClass, nme.eq, methOfAnyRef(BooleanType), Final)
307-
lazy val Object_ne: TermSymbol = enterMethod(ObjectClass, nme.ne, methOfAnyRef(BooleanType), Final)
311+
// TODO(abeln): modify usage sites to use `RefEq_eq/ne` once we migrate to explicit nulls?
312+
lazy val Object_eq: TermSymbol = if (ctx.settings.YexplicitNulls.value) {
313+
RefEq_eq
314+
} else {
315+
enterMethod(ObjectClass, nme.eq, methOfAnyRef(BooleanType), Final)
316+
}
317+
lazy val Object_ne: TermSymbol = if (ctx.settings.YexplicitNulls.value) {
318+
RefEq_ne
319+
} else {
320+
enterMethod(ObjectClass, nme.ne, methOfAnyRef(BooleanType), Final)
321+
}
308322
lazy val Object_synchronized: TermSymbol = enterPolyMethod(ObjectClass, nme.synchronized_, 1,
309323
pt => MethodType(List(pt.paramRefs(0)), pt.paramRefs(0)), Final)
310324
lazy val Object_clone: TermSymbol = enterMethod(ObjectClass, nme.clone_, MethodType(Nil, ObjectType), Protected)
@@ -351,11 +365,40 @@ class Definitions {
351365
ScalaPackageClass, tpnme.Nothing, AbstractFinal, List(AnyClass.typeRef))
352366
def NothingType: TypeRef = NothingClass.typeRef
353367
lazy val RuntimeNothingModuleRef: TermRef = ctx.requiredModuleRef("scala.runtime.Nothing")
368+
369+
/** `RefEq` is the trait defining the reference equality operators (`eq`, `neq`).
370+
* It's a supertype of both `AnyRef` (which is non-nullable) and `Null`.
371+
* With `RefEq`, we can compare `null` for reference equality a la `null eq foo`.
372+
* `RefEq` is just a marker trait and there's no corresponding class file, since it gets erased to `Object`.
373+
*/
374+
lazy val RefEqClass: ClassSymbol = {
375+
assert(ctx.settings.YexplicitNulls.value)
376+
enterCompleteClassSymbol(ScalaPackageClass, tpnme.RefEq, Trait, AnyClass.typeRef :: Nil)
377+
}
378+
def RefEqType: TypeRef = {
379+
assert(ctx.settings.YexplicitNulls.value)
380+
RefEqClass.typeRef
381+
}
382+
383+
lazy val RefEq_eq: TermSymbol = {
384+
assert(ctx.settings.YexplicitNulls.value)
385+
enterMethod(RefEqClass, nme.eq, MethodType(List(RefEqType), BooleanType), Final)
386+
}
387+
lazy val RefEq_ne: TermSymbol = {
388+
assert(ctx.settings.YexplicitNulls.value)
389+
enterMethod(RefEqClass, nme.ne, MethodType(List(RefEqType), BooleanType), Final)
390+
}
391+
392+
def RefEqMethods: List[TermSymbol] = {
393+
assert(ctx.settings.YexplicitNulls.value)
394+
List(RefEq_eq, RefEq_ne)
395+
}
396+
354397
lazy val NullClass: ClassSymbol = {
355398
val parents = if (ctx.settings.YexplicitNulls.value) {
356-
List(AnyClass.typeRef)
399+
List(AnyType, RefEqType)
357400
} else {
358-
List(ObjectClass.typeRef)
401+
List(ObjectType)
359402
}
360403
enterCompleteClassSymbol(ScalaPackageClass, tpnme.Null, AbstractFinal, parents)
361404
}
@@ -1205,7 +1248,14 @@ class Definitions {
12051248
lazy val UnqualifiedOwnerTypes: Set[NamedType] =
12061249
RootImportTypes.toSet[NamedType] ++ RootImportTypes.map(_.symbol.moduleClass.typeRef)
12071250

1208-
lazy val NotRuntimeClasses: Set[Symbol] = Set(AnyClass, AnyValClass, NullClass, NothingClass)
1251+
lazy val NotRuntimeClasses: Set[Symbol] = {
1252+
val classes: Set[Symbol] = Set(AnyClass, AnyValClass, NullClass, NothingClass)
1253+
if (ctx.settings.YexplicitNulls.value) {
1254+
classes + RefEqClass
1255+
} else {
1256+
classes
1257+
}
1258+
}
12091259

12101260
/** Classes that are known not to have an initializer irrespective of
12111261
* whether NoInits is set. Note: FunctionXXLClass is in this set
@@ -1400,13 +1450,20 @@ class Definitions {
14001450
def isValueSubClass(sym1: Symbol, sym2: Symbol): Boolean =
14011451
valueTypeEnc(sym2.asClass.name) % valueTypeEnc(sym1.asClass.name) == 0
14021452

1403-
lazy val specialErasure: SimpleIdentityMap[Symbol, ClassSymbol] =
1404-
SimpleIdentityMap.Empty[Symbol]
1405-
.updated(AnyClass, ObjectClass)
1406-
.updated(AnyValClass, ObjectClass)
1407-
.updated(SingletonClass, ObjectClass)
1408-
.updated(TupleClass, ObjectClass)
1409-
.updated(NonEmptyTupleClass, ProductClass)
1453+
lazy val specialErasure: SimpleIdentityMap[Symbol, ClassSymbol] = {
1454+
val idMap =
1455+
SimpleIdentityMap.Empty[Symbol]
1456+
.updated(AnyClass, ObjectClass)
1457+
.updated(AnyValClass, ObjectClass)
1458+
.updated(SingletonClass, ObjectClass)
1459+
.updated(TupleClass, ObjectClass)
1460+
.updated(NonEmptyTupleClass, ProductClass)
1461+
if (ctx.settings.YexplicitNulls.value) {
1462+
idMap.updated(RefEqClass, ObjectClass)
1463+
} else {
1464+
idMap
1465+
}
1466+
}
14101467

14111468
// ----- Initialization ---------------------------------------------------
14121469

@@ -1426,7 +1483,7 @@ class Definitions {
14261483
SingletonClass,
14271484
EqualsPatternClass)
14281485

1429-
if (ctx.settings.YexplicitNulls.value) synth :+ JavaNullAlias
1486+
if (ctx.settings.YexplicitNulls.value) synth ++ List(JavaNullAlias, RefEqClass)
14301487
else synth
14311488
}
14321489

@@ -1435,8 +1492,14 @@ class Definitions {
14351492
OpsPackageClass)
14361493

14371494
/** Lists core methods that don't have underlying bytecode, but are synthesized on-the-fly in every reflection universe */
1438-
lazy val syntheticCoreMethods: List[TermSymbol] =
1439-
AnyMethods ++ ObjectMethods ++ List(String_+, throwMethod)
1495+
lazy val syntheticCoreMethods: List[TermSymbol] = {
1496+
val methods = AnyMethods ++ ObjectMethods ++ List(String_+, throwMethod)
1497+
if (ctx.settings.YexplicitNulls.value) {
1498+
methods ++ RefEqMethods
1499+
} else {
1500+
methods
1501+
}
1502+
}
14401503

14411504
lazy val reservedScalaClassNames: Set[Name] = syntheticScalaClasses.map(_.name).toSet
14421505

compiler/src/dotty/tools/dotc/core/JavaNullInterop.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,8 @@ object JavaNullInterop {
146146
!alreadyNullable && (tp match {
147147
case tp: TypeRef =>
148148
// We don't modify value types because they're non-nullable even in Java.
149-
// We don't modify `Any` because it's already nullable.
150-
!tp.symbol.isValueClass && !tp.isRef(defn.AnyClass)
149+
// We don't modify `Any` or `RefEq` because they're already nullable.
150+
!tp.symbol.isValueClass && !tp.isRef(defn.AnyClass) && !tp.isRef(defn.RefEqClass)
151151
case _ => true
152152
})
153153
}

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ object StdNames {
160160
scala.List(Byte, Char, Short, Int, Long, Float, Double, Boolean, Unit)
161161

162162
// some types whose companions we utilize
163+
final val RefEq: N = "RefEq"
163164
final val AnyRef: N = "AnyRef"
164165
final val Array: N = "Array"
165166
final val List: N = "List"

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,13 @@ object Types {
563563
val rsym = r.classSymbol
564564
if (lsym isSubClass rsym) rsym
565565
else if (rsym isSubClass lsym) lsym
566+
else if (ctx.settings.YexplicitNulls.value && this.isNullableUnion) {
567+
val OrType(left, _) = this.normNullableUnion
568+
// If `left` is a reference type, then the class LUB of `left | Null` is `RefEq`.
569+
// This is another one-of case that keeps this method sound, but not complete.
570+
if (left.classSymbol isSubClass defn.ObjectClass) defn.RefEqClass
571+
else NoSymbol
572+
}
566573
else NoSymbol
567574
case _ =>
568575
NoSymbol

compiler/src/dotty/tools/dotc/transform/Erasure.scala

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,14 @@ class Erasure extends Phase with DenotTransformer {
5454
// After erasure, all former Any members are now Object members
5555
val ClassInfo(pre, _, ps, decls, selfInfo) = ref.info
5656
val extendedScope = decls.cloneScope
57-
for (decl <- defn.AnyClass.classInfo.decls)
58-
if (!decl.isConstructor) extendedScope.enter(decl)
57+
def addDecls(cls: ClassSymbol) = {
58+
for (decl <- cls.classInfo.decls)
59+
if (!decl.isConstructor) extendedScope.enter(decl)
60+
}
61+
addDecls(defn.AnyClass)
62+
if (ctx.settings.YexplicitNulls.value) {
63+
addDecls(defn.RefEqClass)
64+
}
5965
ref.copySymDenotation(
6066
info = transformInfo(ref.symbol,
6167
ClassInfo(pre, defn.ObjectClass, ps, extendedScope, selfInfo))
@@ -68,7 +74,14 @@ class Erasure extends Phase with DenotTransformer {
6874
defn.ObjectClass.primaryConstructor
6975
else oldSymbol
7076
val oldOwner = ref.owner
71-
val newOwner = if (oldOwner eq defn.AnyClass) defn.ObjectClass else oldOwner
77+
val newOwner = {
78+
val shouldBeObject = if (ctx.settings.YexplicitNulls.value) {
79+
(oldOwner eq defn.AnyClass) || (oldOwner eq defn.RefEqClass)
80+
} else {
81+
oldOwner eq defn.AnyClass
82+
}
83+
if (shouldBeObject) defn.ObjectClass else oldOwner
84+
}
7285
val oldInfo = ref.info
7386
val newInfo = transformInfo(oldSymbol, oldInfo)
7487
val oldFlags = ref.flags

tests/explicit-nulls/pos/ref-eq.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
2+
// Test that the synthetic trait `RefEq` is usable.
3+
4+
class Test {
5+
val x: RefEq = null
6+
val y: RefEq = "hello"
7+
8+
x.eq(y)
9+
y.eq(x)
10+
x.eq(x)
11+
y.eq(y)
12+
13+
x.ne(y)
14+
y.ne(x)
15+
x.ne(x)
16+
y.ne(y)
17+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
2+
object Test {
3+
4+
def main(args: Array[String]): Unit = {
5+
assert(null.eq(null))
6+
assert(!null.ne(null))
7+
8+
assert(!null.eq("hello"))
9+
assert(null.ne("hello"))
10+
11+
assert(!null.eq(4))
12+
assert(null.ne(4))
13+
14+
assert(!"hello".eq(null))
15+
assert("hello".ne(null))
16+
17+
assert(!4.eq(null))
18+
assert(4.ne(null))
19+
20+
val x: String|Null = null
21+
assert(x.eq(null))
22+
assert(!x.ne(null))
23+
24+
val x2: AnyRef|Null = "world"
25+
assert(!x2.eq(null))
26+
assert(x2.ne(null))
27+
}
28+
}

0 commit comments

Comments
 (0)