Skip to content

Commit e273e97

Browse files
committed
implement readResolve in terms of fromOrdinalDollar method
value cases are collected in EnumCaseCount attachment during nextOrdinal, and an attachment DefinesEnumLookupMethods determines when the final case has been reached, at which point the methods can be generated that depend on the cases.
1 parent 5d0eea3 commit e273e97

File tree

6 files changed

+122
-41
lines changed

6 files changed

+122
-41
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,8 @@ object desugar {
476476
val (enumCases, enumStats) = stats.partition(DesugarEnums.isEnumCase)
477477
if (enumCases.isEmpty)
478478
report.error(EnumerationsShouldNotBeEmpty(cdef), namePos)
479+
else
480+
enumCases.last.pushAttachment(DesugarEnums.DefinesEnumLookupMethods, ())
479481
val enumCompanionRef = TermRefTree()
480482
val enumImport =
481483
Import(enumCompanionRef, enumCases.flatMap(caseIds).map(ImportSelector(_)))
@@ -568,7 +570,7 @@ object desugar {
568570
// Note: copy default parameters need @uncheckedVariance; see
569571
// neg/t1843-variances.scala for a test case. The test would give
570572
// two errors without @uncheckedVariance, one of them spurious.
571-
val caseClassMeths = {
573+
val (caseClassMeths, enumScaffolding) = {
572574
def syntheticProperty(name: TermName, tpt: Tree, rhs: Tree) =
573575
DefDef(name, Nil, Nil, tpt, rhs).withMods(synthetic)
574576

@@ -586,9 +588,11 @@ object desugar {
586588
yield syntheticProperty(selName, caseParams(i).tpt,
587589
Select(This(EmptyTypeIdent), caseParams(i).name))
588590

589-
def enumMeths =
590-
if (isEnumCase) ordinalMethLit(nextOrdinal(CaseKind.Class)._1) :: enumLabelLit(className.toString) :: Nil
591-
else Nil
591+
def enumCaseMeths =
592+
if isEnumCase then
593+
val (ordinal, scaffolding) = nextOrdinal(className, CaseKind.Class, definesEnumLookupMethods(cdef))
594+
(ordinalMethLit(ordinal) :: enumLabelLit(className.toString) :: Nil, scaffolding)
595+
else (Nil, Nil)
592596
def copyMeths = {
593597
val hasRepeatedParam = constrVparamss.exists(_.exists {
594598
case ValDef(_, tpt, _) => isRepeated(tpt)
@@ -607,8 +611,9 @@ object desugar {
607611
}
608612

609613
if (isCaseClass)
610-
copyMeths ::: enumMeths ::: productElemMeths
611-
else Nil
614+
val (enumMeths, enumScaffolding) = enumCaseMeths
615+
(copyMeths ::: enumMeths ::: productElemMeths, enumScaffolding)
616+
else (Nil, Nil)
612617
}
613618

614619
var parents1 = parents
@@ -809,7 +814,7 @@ object desugar {
809814
case _ =>
810815
}
811816

812-
flatTree(cdef1 :: companions ::: implicitWrappers)
817+
flatTree(cdef1 :: companions ::: implicitWrappers ::: enumScaffolding)
813818
}.reporting(i"desugared: $result", Printers.desugar)
814819

815820
/** Expand
@@ -862,7 +867,7 @@ object desugar {
862867
else if (isEnumCase) {
863868
typeParamIsReferenced(enumClass.typeParams, Nil, Nil, impl.parents)
864869
// used to check there are no illegal references to enum's type parameters in parents
865-
expandEnumModule(moduleName, impl, mods, mdef.span)
870+
expandEnumModule(moduleName, impl, mods, definesEnumLookupMethods(mdef), mdef.span)
866871
}
867872
else {
868873
val clsName = moduleName.moduleClassName
@@ -990,6 +995,12 @@ object desugar {
990995

991996
private def inventTypeName(tree: Tree)(using Context): String = typeNameExtractor("", tree)
992997

998+
/**This will check if this def tree is marked to define enum lookup methods,
999+
* this is not recommended to call more than once per tree
1000+
*/
1001+
private def definesEnumLookupMethods(ddef: DefTree): Boolean =
1002+
ddef.removeAttachment(DefinesEnumLookupMethods).isDefined
1003+
9931004
/** val p1, ..., pN: T = E
9941005
* ==>
9951006
* makePatDef[[val p1: T1 = E]]; ...; makePatDef[[val pN: TN = E]]
@@ -1001,11 +1012,15 @@ object desugar {
10011012
def patDef(pdef: PatDef)(using Context): Tree = flatTree {
10021013
val PatDef(mods, pats, tpt, rhs) = pdef
10031014
if (mods.isEnumCase)
1004-
pats map {
1005-
case id: Ident =>
1006-
expandSimpleEnumCase(id.name.asTermName, mods,
1015+
def expand(id: Ident, definesLookups: Boolean) =
1016+
expandSimpleEnumCase(id.name.asTermName, mods, definesLookups,
10071017
Span(id.span.start, id.span.end, id.span.start))
1008-
}
1018+
1019+
val ids = pats.asInstanceOf[List[Ident]]
1020+
if definesEnumLookupMethods(pdef) then
1021+
ids.init.map(expand(_, false)) ::: expand(ids.last, true) :: Nil
1022+
else
1023+
ids.map(expand(_, false))
10091024
else {
10101025
val pats1 = if (tpt.isEmpty) pats else pats map (Typed(_, tpt))
10111026
pats1 map (makePatDef(pdef, mods, _, rhs))

compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,15 @@ object DesugarEnums {
2020
val Simple, Object, Class: Value = Value
2121
}
2222

23-
/** Attachment containing the number of enum cases and the smallest kind that was seen so far. */
24-
val EnumCaseCount: Property.Key[(Int, DesugarEnums.CaseKind.Value)] = Property.Key()
23+
/** Attachment containing the number of enum cases, the smallest kind that was seen so far,
24+
* and a list of all the value cases with their ordinals.
25+
*/
26+
val EnumCaseCount: Property.Key[(Int, CaseKind.Value, List[(Int, TermName)])] = Property.Key()
27+
28+
/** Attachment signalling that when this definition is desugared, it should add any additional
29+
* lookup methods for enums.
30+
*/
31+
val DefinesEnumLookupMethods: Property.Key[Unit] = Property.Key()
2532

2633
/** The enumeration class that belongs to an enum case. This works no matter
2734
* whether the case is still in the enum class or it has been transferred to the
@@ -122,6 +129,21 @@ object DesugarEnums {
122129
valueOfDef :: Nil
123130
}
124131

132+
private def enumLookupMethods(cases: List[(Int, TermName)])(using Context): List[Tree] =
133+
if isJavaEnum || cases.isEmpty then Nil
134+
else
135+
val defaultCase =
136+
val ord = Ident(nme.ordinal)
137+
val err = Throw(New(TypeTree(defn.IndexOutOfBoundsException.typeRef), List(Select(ord, nme.toString_) :: Nil)))
138+
CaseDef(ord, EmptyTree, err)
139+
val valueCases = cases.map((i, name) =>
140+
CaseDef(Literal(Constant(i)), EmptyTree, Ident(name))
141+
) ::: defaultCase :: Nil
142+
val fromOrdinalDef = DefDef(nme.fromOrdinalDollar, Nil, List(param(nme.ordinalDollar_, defn.IntType) :: Nil),
143+
rawRef(enumClass.typeRef), Match(Ident(nme.ordinalDollar_), valueCases))
144+
.withFlags(Synthetic | Private)
145+
fromOrdinalDef :: Nil
146+
125147
/** A creation method for a value of enum type `E`, which is defined as follows:
126148
*
127149
* private def $new(_$ordinal: Int, $name: String) = new E with scala.runtime.EnumValue {
@@ -256,16 +278,22 @@ object DesugarEnums {
256278
* - scaffolding containing the necessary definitions for singleton enum cases
257279
* unless that scaffolding was already generated by a previous call to `nextEnumKind`.
258280
*/
259-
def nextOrdinal(kind: CaseKind.Value)(using Context): (Int, List[Tree]) = {
260-
val (count, seenKind) = ctx.tree.removeAttachment(EnumCaseCount).getOrElse((0, CaseKind.Class))
261-
val minKind = if (kind < seenKind) kind else seenKind
262-
ctx.tree.pushAttachment(EnumCaseCount, (count + 1, minKind))
263-
val scaffolding =
281+
def nextOrdinal(name: Name, kind: CaseKind.Value, definesLookups: Boolean)(using Context): (Int, List[Tree]) = {
282+
val (ordinal, seenKind, seenCases) = ctx.tree.removeAttachment(EnumCaseCount).getOrElse((0, CaseKind.Class, Nil))
283+
val minKind = if kind < seenKind then kind else seenKind
284+
val cases = name match
285+
case name: TermName => (ordinal, name) :: seenCases
286+
case _ => seenCases
287+
ctx.tree.pushAttachment(EnumCaseCount, (ordinal + 1, minKind, cases))
288+
val scaffolding0 =
264289
if (kind >= seenKind) Nil
265290
else if (kind == CaseKind.Object) enumScaffolding
266291
else if (seenKind == CaseKind.Object) enumValueCreator :: Nil
267292
else enumScaffolding :+ enumValueCreator
268-
(count, scaffolding)
293+
val scaffolding =
294+
if definesLookups then scaffolding0 ::: enumLookupMethods(cases.reverse)
295+
else scaffolding0
296+
(ordinal, scaffolding)
269297
}
270298

271299
def param(name: TermName, typ: Type)(using Context) =
@@ -286,13 +314,13 @@ object DesugarEnums {
286314
enumLabelMeth(Literal(Constant(name)))
287315

288316
/** Expand a module definition representing a parameterless enum case */
289-
def expandEnumModule(name: TermName, impl: Template, mods: Modifiers, span: Span)(using Context): Tree = {
317+
def expandEnumModule(name: TermName, impl: Template, mods: Modifiers, definesLookups: Boolean, span: Span)(using Context): Tree = {
290318
assert(impl.body.isEmpty)
291319
if (!enumClass.exists) EmptyTree
292320
else if (impl.parents.isEmpty)
293-
expandSimpleEnumCase(name, mods, span)
321+
expandSimpleEnumCase(name, mods, definesLookups, span)
294322
else {
295-
val (tag, scaffolding) = nextOrdinal(CaseKind.Object)
323+
val (tag, scaffolding) = nextOrdinal(name, CaseKind.Object, definesLookups)
296324
val ordinalDef = if isJavaEnum then Nil else ordinalMethLit(tag) :: Nil
297325
val enumLabelDef = enumLabelLit(name.toString)
298326
val impl1 = cpy.Template(impl)(
@@ -305,15 +333,15 @@ object DesugarEnums {
305333
}
306334

307335
/** Expand a simple enum case */
308-
def expandSimpleEnumCase(name: TermName, mods: Modifiers, span: Span)(using Context): Tree =
336+
def expandSimpleEnumCase(name: TermName, mods: Modifiers, definesLookups: Boolean, span: Span)(using Context): Tree =
309337
if (!enumClass.exists) EmptyTree
310338
else if (enumClass.typeParams.nonEmpty) {
311339
val parent = interpolatedEnumParent(span)
312340
val impl = Template(emptyConstructor, parent :: Nil, Nil, EmptyValDef, Nil)
313-
expandEnumModule(name, impl, mods, span)
341+
expandEnumModule(name, impl, mods, definesLookups, span)
314342
}
315343
else {
316-
val (tag, scaffolding) = nextOrdinal(CaseKind.Simple)
344+
val (tag, scaffolding) = nextOrdinal(name, CaseKind.Simple, definesLookups)
317345
val creator = Apply(Ident(nme.DOLLAR_NEW), List(Literal(Constant(tag)), Literal(Constant(name.toString))))
318346
val vdef = ValDef(name, enumClassRef, creator).withMods(mods.withAddedFlags(EnumValue, span))
319347
flatTree(scaffolding ::: vdef :: Nil).withSpan(span)

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,7 @@ object StdNames {
615615
val using: N = "using"
616616
val value: N = "value"
617617
val valueOf : N = "valueOf"
618+
val fromOrdinalDollar: N = "$fromOrdinal"
618619
val values: N = "values"
619620
val view_ : N = "view"
620621
val wait_ : N = "wait"
@@ -623,6 +624,7 @@ object StdNames {
623624
val WorksheetWrapper: N = "WorksheetWrapper"
624625
val wrap: N = "wrap"
625626
val writeReplace: N = "writeReplace"
627+
val readResolve: N = "readResolve"
626628
val zero: N = "zero"
627629
val zip: N = "zip"
628630
val nothingRuntimeClass: N = "scala.runtime.Nothing$"

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -373,10 +373,19 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
373373
.filterWithPredicate(s => s.signature == Signature(defn.AnyRefType, isJava = false))
374374
.exists
375375

376+
private def hasReadResolve(clazz: ClassSymbol)(using Context): Boolean =
377+
clazz.membersNamed(nme.readResolve)
378+
.filterWithPredicate(s => s.signature == Signature(defn.AnyRefType, isJava = false))
379+
.exists
380+
376381
private def writeReplaceDef(clazz: ClassSymbol)(using Context): TermSymbol =
377382
newSymbol(clazz, nme.writeReplace, Method | Private | Synthetic,
378383
MethodType(Nil, defn.AnyRefType), coord = clazz.coord).entered.asTerm
379384

385+
private def readResolveDef(clazz: ClassSymbol)(using Context): TermSymbol =
386+
newSymbol(clazz, nme.readResolve, Method | Private | Synthetic,
387+
MethodType(Nil, defn.AnyRefType), coord = clazz.coord).entered.asTerm
388+
380389
/** If this is a static object `Foo`, add the method:
381390
*
382391
* private def writeReplace(): AnyRef =
@@ -405,22 +414,22 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
405414
/** If this is the class backing a serializable singleton enum value with base class `MyEnum`,
406415
* and not deriving from `java.lang.Enum` add the method:
407416
*
408-
* private def writeReplace(): AnyRef =
409-
* new scala.runtime.EnumValueSerializationProxy(classOf[MyEnum], this.ordinal)
417+
* private def readResolve(): AnyRef =
418+
* MyEnum.$fromOrdinal(this.ordinal)
410419
*
411420
* unless an implementation already exists, otherwise do nothing.
412421
*/
413422
def serializableEnumValueMethod(clazz: ClassSymbol)(using Context): List[Tree] =
414423
if clazz.isEnumValueImplementation
415424
&& !clazz.derivesFrom(defn.JavaEnumClass)
416425
&& clazz.isSerializable
417-
&& !hasWriteReplace(clazz)
426+
&& !hasReadResolve(clazz)
418427
then
419428
List(
420-
DefDef(writeReplaceDef(clazz),
421-
_ => New(defn.EnumValueSerializationProxyClass.typeRef,
422-
defn.EnumValueSerializationProxyConstructor,
423-
List(Literal(Constant(clazz.classParents.head)), This(clazz).select(nme.ordinal).ensureApplied)))
429+
DefDef(readResolveDef(clazz),
430+
_ => ref(clazz.owner.owner.sourceModule)
431+
.select(nme.fromOrdinalDollar)
432+
.appliedTo(This(clazz).select(nme.ordinal).ensureApplied))
424433
.withSpan(ctx.owner.span.focus))
425434
else
426435
Nil

tests/run/enums-serialization-compat.scala

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,49 @@ import java.io._
22
import scala.util.Using
33

44
enum JColor extends java.lang.Enum[JColor]:
5-
case Red
5+
case Red // java enum has magic JVM support
66

77
enum SColor:
8-
case Green
8+
case Green // simple case last
99

1010
enum SColorTagged[T]:
11-
case Blue extends SColorTagged[Unit]
11+
case Blue extends SColorTagged[Unit]
12+
case Rgb(r: Byte, g: Byte, b: Byte) extends SColorTagged[(Byte, Byte, Byte)] // mixing pattern kinds
13+
case Indigo extends SColorTagged[Unit]
14+
case Cmyk(c: Byte, m: Byte, y: Byte, k: Byte) extends SColorTagged[(Byte, Byte, Byte, Byte)] // class case last
15+
16+
enum Nucleobase:
17+
case A,C,G,T // patdef last
18+
19+
enum MyClassTag[T](wrapped: Class[?]):
20+
case IntTag extends MyClassTag[Int](classOf[Int])
21+
case UnitTag extends MyClassTag[Unit](classOf[Unit]) // value case last
22+
23+
extension (ref: AnyRef) def aliases(compare: AnyRef) = assert(ref eq compare, compare)
1224

1325
@main def Test = Using.Manager({ use =>
1426
val buf = use(ByteArrayOutputStream())
1527
val out = use(ObjectOutputStream(buf))
16-
Seq(JColor.Red, SColor.Green, SColorTagged.Blue).foreach(out.writeObject)
28+
Seq(JColor.Red, SColor.Green, SColorTagged.Blue, SColorTagged.Indigo).foreach(out.writeObject)
29+
Seq(Nucleobase.A, Nucleobase.C, Nucleobase.G, Nucleobase.T).foreach(out.writeObject)
30+
Seq(MyClassTag.IntTag, MyClassTag.UnitTag).foreach(out.writeObject)
1731
val read = use(ByteArrayInputStream(buf.toByteArray))
1832
val in = use(ObjectInputStream(read))
19-
val Seq(Red @ _, Green @ _, Blue @ _) = (1 to 3).map(_ => in.readObject)
20-
assert(Red eq JColor.Red, JColor.Red)
21-
assert(Green eq SColor.Green, SColor.Green)
22-
assert(Blue eq SColorTagged.Blue, SColorTagged.Blue)
33+
34+
val Seq(Red @ _, Green @ _, Blue @ _, Indigo @ _) = (1 to 4).map(_ => in.readObject)
35+
Red aliases JColor.Red
36+
Green aliases SColor.Green
37+
Blue aliases SColorTagged.Blue
38+
Indigo aliases SColorTagged.Indigo
39+
40+
val Seq(A @ _, C @ _, G @ _, T @ _) = (1 to 4).map(_ => in.readObject)
41+
A aliases Nucleobase.A
42+
C aliases Nucleobase.C
43+
G aliases Nucleobase.G
44+
T aliases Nucleobase.T
45+
46+
val Seq(IntTag @ _, UnitTag @ _) = (1 to 2).map(_ => in.readObject)
47+
IntTag aliases MyClassTag.IntTag
48+
UnitTag aliases MyClassTag.UnitTag
49+
2350
}).get

0 commit comments

Comments
 (0)