Skip to content

Commit 7eed861

Browse files
committed
implement readResolve in terms of fromOrdinalDollar method
value cases are collected in EnumCaseCount attachment during nextOrdinal, and an attachment DefinesEnumLookupMethods determines when the final case has been reached, at which point the methods can be generated that depend on the cases.
1 parent f2018f0 commit 7eed861

File tree

6 files changed

+122
-41
lines changed

6 files changed

+122
-41
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,8 @@ object desugar {
476476
val (enumCases, enumStats) = stats.partition(DesugarEnums.isEnumCase)
477477
if (enumCases.isEmpty)
478478
report.error(EnumerationsShouldNotBeEmpty(cdef), namePos)
479+
else
480+
enumCases.last.pushAttachment(DesugarEnums.DefinesEnumLookupMethods, ())
479481
val enumCompanionRef = TermRefTree()
480482
val enumImport =
481483
Import(enumCompanionRef, enumCases.flatMap(caseIds).map(ImportSelector(_)))
@@ -568,7 +570,7 @@ object desugar {
568570
// Note: copy default parameters need @uncheckedVariance; see
569571
// neg/t1843-variances.scala for a test case. The test would give
570572
// two errors without @uncheckedVariance, one of them spurious.
571-
val caseClassMeths = {
573+
val (caseClassMeths, enumScaffolding) = {
572574
def syntheticProperty(name: TermName, tpt: Tree, rhs: Tree) =
573575
DefDef(name, Nil, Nil, tpt, rhs).withMods(synthetic)
574576

@@ -586,9 +588,11 @@ object desugar {
586588
yield syntheticProperty(selName, caseParams(i).tpt,
587589
Select(This(EmptyTypeIdent), caseParams(i).name))
588590

589-
def enumMeths =
590-
if (isEnumCase) ordinalMethLit(nextOrdinal(CaseKind.Class)._1) :: enumLabelLit(className.toString) :: Nil
591-
else Nil
591+
def enumCaseMeths =
592+
if isEnumCase then
593+
val (ordinal, scaffolding) = nextOrdinal(className, CaseKind.Class, definesEnumLookupMethods(cdef))
594+
(ordinalMethLit(ordinal) :: enumLabelLit(className.toString) :: Nil, scaffolding)
595+
else (Nil, Nil)
592596
def copyMeths = {
593597
val hasRepeatedParam = constrVparamss.exists(_.exists {
594598
case ValDef(_, tpt, _) => isRepeated(tpt)
@@ -607,8 +611,9 @@ object desugar {
607611
}
608612

609613
if (isCaseClass)
610-
copyMeths ::: enumMeths ::: productElemMeths
611-
else Nil
614+
val (enumMeths, enumScaffolding) = enumCaseMeths
615+
(copyMeths ::: enumMeths ::: productElemMeths, enumScaffolding)
616+
else (Nil, Nil)
612617
}
613618

614619
var parents1 = parents
@@ -806,7 +811,7 @@ object desugar {
806811
case _ =>
807812
}
808813

809-
flatTree(cdef1 :: companions ::: implicitWrappers)
814+
flatTree(cdef1 :: companions ::: implicitWrappers ::: enumScaffolding)
810815
}.reporting(i"desugared: $result", Printers.desugar)
811816

812817
/** Expand
@@ -859,7 +864,7 @@ object desugar {
859864
else if (isEnumCase) {
860865
typeParamIsReferenced(enumClass.typeParams, Nil, Nil, impl.parents)
861866
// used to check there are no illegal references to enum's type parameters in parents
862-
expandEnumModule(moduleName, impl, mods, mdef.span)
867+
expandEnumModule(moduleName, impl, mods, definesEnumLookupMethods(mdef), mdef.span)
863868
}
864869
else {
865870
val clsName = moduleName.moduleClassName
@@ -987,6 +992,12 @@ object desugar {
987992

988993
private def inventTypeName(tree: Tree)(using Context): String = typeNameExtractor("", tree)
989994

995+
/**This will check if this def tree is marked to define enum lookup methods,
996+
* this is not recommended to call more than once per tree
997+
*/
998+
private def definesEnumLookupMethods(ddef: DefTree): Boolean =
999+
ddef.removeAttachment(DefinesEnumLookupMethods).isDefined
1000+
9901001
/** val p1, ..., pN: T = E
9911002
* ==>
9921003
* makePatDef[[val p1: T1 = E]]; ...; makePatDef[[val pN: TN = E]]
@@ -998,11 +1009,15 @@ object desugar {
9981009
def patDef(pdef: PatDef)(using Context): Tree = flatTree {
9991010
val PatDef(mods, pats, tpt, rhs) = pdef
10001011
if (mods.isEnumCase)
1001-
pats map {
1002-
case id: Ident =>
1003-
expandSimpleEnumCase(id.name.asTermName, mods,
1012+
def expand(id: Ident, definesLookups: Boolean) =
1013+
expandSimpleEnumCase(id.name.asTermName, mods, definesLookups,
10041014
Span(id.span.start, id.span.end, id.span.start))
1005-
}
1015+
1016+
val ids = pats.asInstanceOf[List[Ident]]
1017+
if definesEnumLookupMethods(pdef) then
1018+
ids.init.map(expand(_, false)) ::: expand(ids.last, true) :: Nil
1019+
else
1020+
ids.map(expand(_, false))
10061021
else {
10071022
val pats1 = if (tpt.isEmpty) pats else pats map (Typed(_, tpt))
10081023
pats1 map (makePatDef(pdef, mods, _, rhs))

compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,15 @@ object DesugarEnums {
2020
val Simple, Object, Class: Value = Value
2121
}
2222

23-
/** Attachment containing the number of enum cases and the smallest kind that was seen so far. */
24-
val EnumCaseCount: Property.Key[(Int, DesugarEnums.CaseKind.Value)] = Property.Key()
23+
/** Attachment containing the number of enum cases, the smallest kind that was seen so far,
24+
* and a list of all the value cases with their ordinals.
25+
*/
26+
val EnumCaseCount: Property.Key[(Int, CaseKind.Value, List[(Int, TermName)])] = Property.Key()
27+
28+
/** Attachment signalling that when this definition is desugared, it should add any additional
29+
* lookup methods for enums.
30+
*/
31+
val DefinesEnumLookupMethods: Property.Key[Unit] = Property.Key()
2532

2633
/** The enumeration class that belongs to an enum case. This works no matter
2734
* whether the case is still in the enum class or it has been transferred to the
@@ -122,6 +129,21 @@ object DesugarEnums {
122129
valueOfDef :: Nil
123130
}
124131

132+
private def enumLookupMethods(cases: List[(Int, TermName)])(using Context): List[Tree] =
133+
if isJavaEnum || cases.isEmpty then Nil
134+
else
135+
val defaultCase =
136+
val ord = Ident(nme.ordinal)
137+
val err = Throw(New(TypeTree(defn.IndexOutOfBoundsException.typeRef), List(Select(ord, nme.toString_) :: Nil)))
138+
CaseDef(ord, EmptyTree, err)
139+
val valueCases = cases.map((i, name) =>
140+
CaseDef(Literal(Constant(i)), EmptyTree, Ident(name))
141+
) ::: defaultCase :: Nil
142+
val fromOrdinalDef = DefDef(nme.fromOrdinalDollar, Nil, List(param(nme.ordinalDollar_, defn.IntType) :: Nil),
143+
rawRef(enumClass.typeRef), Match(Ident(nme.ordinalDollar_), valueCases))
144+
.withFlags(Synthetic | Private)
145+
fromOrdinalDef :: Nil
146+
125147
/** A creation method for a value of enum type `E`, which is defined as follows:
126148
*
127149
* private def $new(_$ordinal: Int, $name: String) = new E with scala.runtime.EnumValue {
@@ -256,16 +278,22 @@ object DesugarEnums {
256278
* - scaffolding containing the necessary definitions for singleton enum cases
257279
* unless that scaffolding was already generated by a previous call to `nextEnumKind`.
258280
*/
259-
def nextOrdinal(kind: CaseKind.Value)(using Context): (Int, List[Tree]) = {
260-
val (count, seenKind) = ctx.tree.removeAttachment(EnumCaseCount).getOrElse((0, CaseKind.Class))
261-
val minKind = if (kind < seenKind) kind else seenKind
262-
ctx.tree.pushAttachment(EnumCaseCount, (count + 1, minKind))
263-
val scaffolding =
281+
def nextOrdinal(name: Name, kind: CaseKind.Value, definesLookups: Boolean)(using Context): (Int, List[Tree]) = {
282+
val (ordinal, seenKind, seenCases) = ctx.tree.removeAttachment(EnumCaseCount).getOrElse((0, CaseKind.Class, Nil))
283+
val minKind = if kind < seenKind then kind else seenKind
284+
val cases = name match
285+
case name: TermName => (ordinal, name) :: seenCases
286+
case _ => seenCases
287+
ctx.tree.pushAttachment(EnumCaseCount, (ordinal + 1, minKind, cases))
288+
val scaffolding0 =
264289
if (kind >= seenKind) Nil
265290
else if (kind == CaseKind.Object) enumScaffolding
266291
else if (seenKind == CaseKind.Object) enumValueCreator :: Nil
267292
else enumScaffolding :+ enumValueCreator
268-
(count, scaffolding)
293+
val scaffolding =
294+
if definesLookups then scaffolding0 ::: enumLookupMethods(cases.reverse)
295+
else scaffolding0
296+
(ordinal, scaffolding)
269297
}
270298

271299
def param(name: TermName, typ: Type)(using Context) =
@@ -286,13 +314,13 @@ object DesugarEnums {
286314
enumLabelMeth(Literal(Constant(name)))
287315

288316
/** Expand a module definition representing a parameterless enum case */
289-
def expandEnumModule(name: TermName, impl: Template, mods: Modifiers, span: Span)(using Context): Tree = {
317+
def expandEnumModule(name: TermName, impl: Template, mods: Modifiers, definesLookups: Boolean, span: Span)(using Context): Tree = {
290318
assert(impl.body.isEmpty)
291319
if (!enumClass.exists) EmptyTree
292320
else if (impl.parents.isEmpty)
293-
expandSimpleEnumCase(name, mods, span)
321+
expandSimpleEnumCase(name, mods, definesLookups, span)
294322
else {
295-
val (tag, scaffolding) = nextOrdinal(CaseKind.Object)
323+
val (tag, scaffolding) = nextOrdinal(name, CaseKind.Object, definesLookups)
296324
val ordinalDef = if isJavaEnum then Nil else ordinalMethLit(tag) :: Nil
297325
val enumLabelDef = enumLabelLit(name.toString)
298326
val impl1 = cpy.Template(impl)(
@@ -305,15 +333,15 @@ object DesugarEnums {
305333
}
306334

307335
/** Expand a simple enum case */
308-
def expandSimpleEnumCase(name: TermName, mods: Modifiers, span: Span)(using Context): Tree =
336+
def expandSimpleEnumCase(name: TermName, mods: Modifiers, definesLookups: Boolean, span: Span)(using Context): Tree =
309337
if (!enumClass.exists) EmptyTree
310338
else if (enumClass.typeParams.nonEmpty) {
311339
val parent = interpolatedEnumParent(span)
312340
val impl = Template(emptyConstructor, parent :: Nil, Nil, EmptyValDef, Nil)
313-
expandEnumModule(name, impl, mods, span)
341+
expandEnumModule(name, impl, mods, definesLookups, span)
314342
}
315343
else {
316-
val (tag, scaffolding) = nextOrdinal(CaseKind.Simple)
344+
val (tag, scaffolding) = nextOrdinal(name, CaseKind.Simple, definesLookups)
317345
val creator = Apply(Ident(nme.DOLLAR_NEW), List(Literal(Constant(tag)), Literal(Constant(name.toString))))
318346
val vdef = ValDef(name, enumClassRef, creator).withMods(mods.withAddedFlags(EnumValue, span))
319347
flatTree(scaffolding ::: vdef :: Nil).withSpan(span)

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,7 @@ object StdNames {
614614
val using: N = "using"
615615
val value: N = "value"
616616
val valueOf : N = "valueOf"
617+
val fromOrdinalDollar: N = "$fromOrdinal"
617618
val values: N = "values"
618619
val view_ : N = "view"
619620
val wait_ : N = "wait"
@@ -622,6 +623,7 @@ object StdNames {
622623
val WorksheetWrapper: N = "WorksheetWrapper"
623624
val wrap: N = "wrap"
624625
val writeReplace: N = "writeReplace"
626+
val readResolve: N = "readResolve"
625627
val zero: N = "zero"
626628
val zip: N = "zip"
627629
val nothingRuntimeClass: N = "scala.runtime.Nothing$"

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -373,10 +373,19 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
373373
.filterWithPredicate(s => s.signature == Signature(defn.AnyRefType, isJava = false))
374374
.exists
375375

376+
private def hasReadResolve(clazz: ClassSymbol)(using Context): Boolean =
377+
clazz.membersNamed(nme.readResolve)
378+
.filterWithPredicate(s => s.signature == Signature(defn.AnyRefType, isJava = false))
379+
.exists
380+
376381
private def writeReplaceDef(clazz: ClassSymbol)(using Context): TermSymbol =
377382
newSymbol(clazz, nme.writeReplace, Method | Private | Synthetic,
378383
MethodType(Nil, defn.AnyRefType), coord = clazz.coord).entered.asTerm
379384

385+
private def readResolveDef(clazz: ClassSymbol)(using Context): TermSymbol =
386+
newSymbol(clazz, nme.readResolve, Method | Private | Synthetic,
387+
MethodType(Nil, defn.AnyRefType), coord = clazz.coord).entered.asTerm
388+
380389
/** If this is a static object `Foo`, add the method:
381390
*
382391
* private def writeReplace(): AnyRef =
@@ -405,22 +414,22 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
405414
/** If this is the class backing a serializable singleton enum value with base class `MyEnum`,
406415
* and not deriving from `java.lang.Enum` add the method:
407416
*
408-
* private def writeReplace(): AnyRef =
409-
* new scala.runtime.EnumValueSerializationProxy(classOf[MyEnum], this.ordinal)
417+
* private def readResolve(): AnyRef =
418+
* MyEnum.$fromOrdinal(this.ordinal)
410419
*
411420
* unless an implementation already exists, otherwise do nothing.
412421
*/
413422
def serializableEnumValueMethod(clazz: ClassSymbol)(using Context): List[Tree] =
414423
if clazz.isEnumValueImplementation
415424
&& !clazz.derivesFrom(defn.JavaEnumClass)
416425
&& clazz.isSerializable
417-
&& !hasWriteReplace(clazz)
426+
&& !hasReadResolve(clazz)
418427
then
419428
List(
420-
DefDef(writeReplaceDef(clazz),
421-
_ => New(defn.EnumValueSerializationProxyClass.typeRef,
422-
defn.EnumValueSerializationProxyConstructor,
423-
List(Literal(Constant(clazz.classParents.head)), This(clazz).select(nme.ordinal).ensureApplied)))
429+
DefDef(readResolveDef(clazz),
430+
_ => ref(clazz.owner.owner.sourceModule)
431+
.select(nme.fromOrdinalDollar)
432+
.appliedTo(This(clazz).select(nme.ordinal).ensureApplied))
424433
.withSpan(ctx.owner.span.focus))
425434
else
426435
Nil

tests/run/enums-serialization-compat.scala

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,49 @@ import java.io._
22
import scala.util.Using
33

44
enum JColor extends java.lang.Enum[JColor]:
5-
case Red
5+
case Red // java enum has magic JVM support
66

77
enum SColor:
8-
case Green
8+
case Green // simple case last
99

1010
enum SColorTagged[T]:
11-
case Blue extends SColorTagged[Unit]
11+
case Blue extends SColorTagged[Unit]
12+
case Rgb(r: Byte, g: Byte, b: Byte) extends SColorTagged[(Byte, Byte, Byte)] // mixing pattern kinds
13+
case Indigo extends SColorTagged[Unit]
14+
case Cmyk(c: Byte, m: Byte, y: Byte, k: Byte) extends SColorTagged[(Byte, Byte, Byte, Byte)] // class case last
15+
16+
enum Nucleobase:
17+
case A,C,G,T // patdef last
18+
19+
enum MyClassTag[T](wrapped: Class[?]):
20+
case IntTag extends MyClassTag[Int](classOf[Int])
21+
case UnitTag extends MyClassTag[Unit](classOf[Unit]) // value case last
22+
23+
extension (ref: AnyRef) def aliases(compare: AnyRef) = assert(ref eq compare, compare)
1224

1325
@main def Test = Using.Manager({ use =>
1426
val buf = use(ByteArrayOutputStream())
1527
val out = use(ObjectOutputStream(buf))
16-
Seq(JColor.Red, SColor.Green, SColorTagged.Blue).foreach(out.writeObject)
28+
Seq(JColor.Red, SColor.Green, SColorTagged.Blue, SColorTagged.Indigo).foreach(out.writeObject)
29+
Seq(Nucleobase.A, Nucleobase.C, Nucleobase.G, Nucleobase.T).foreach(out.writeObject)
30+
Seq(MyClassTag.IntTag, MyClassTag.UnitTag).foreach(out.writeObject)
1731
val read = use(ByteArrayInputStream(buf.toByteArray))
1832
val in = use(ObjectInputStream(read))
19-
val Seq(Red @ _, Green @ _, Blue @ _) = (1 to 3).map(_ => in.readObject)
20-
assert(Red eq JColor.Red, JColor.Red)
21-
assert(Green eq SColor.Green, SColor.Green)
22-
assert(Blue eq SColorTagged.Blue, SColorTagged.Blue)
33+
34+
val Seq(Red @ _, Green @ _, Blue @ _, Indigo @ _) = (1 to 4).map(_ => in.readObject)
35+
Red aliases JColor.Red
36+
Green aliases SColor.Green
37+
Blue aliases SColorTagged.Blue
38+
Indigo aliases SColorTagged.Indigo
39+
40+
val Seq(A @ _, C @ _, G @ _, T @ _) = (1 to 4).map(_ => in.readObject)
41+
A aliases Nucleobase.A
42+
C aliases Nucleobase.C
43+
G aliases Nucleobase.G
44+
T aliases Nucleobase.T
45+
46+
val Seq(IntTag @ _, UnitTag @ _) = (1 to 2).map(_ => in.readObject)
47+
IntTag aliases MyClassTag.IntTag
48+
UnitTag aliases MyClassTag.UnitTag
49+
2350
}).get

0 commit comments

Comments
 (0)