@@ -16,7 +16,6 @@ import dotty.tools.dotc.ast.Trees._
16
16
import SymUtils ._
17
17
18
18
import annotation .threadUnsafe
19
- import collection .mutable
20
19
21
20
object CompleteJavaEnums {
22
21
val name : String = " completeJavaEnums"
@@ -117,12 +116,13 @@ class CompleteJavaEnums extends MiniPhase with InfoTransformer { thisPhase =>
117
116
&& cls.owner.owner.linkedClass.derivesFromJavaEnum
118
117
119
118
@ threadUnsafe
120
- private lazy val enumCaseOrdinals : mutable. Map [ Symbol , Int ] = mutable. AnyRefMap .empty
119
+ private lazy val enumCaseOrdinals : MutableSymbolMap [ Int ] = newMutableSymbolMap
121
120
122
121
private def registerEnumClass (cls : Symbol )(using Context ): Unit =
123
122
cls.children.zipWithIndex.foreach(enumCaseOrdinals.put)
124
123
125
- private def ordinalFor (enumCase : Symbol ): Int = enumCaseOrdinals.remove(enumCase).get
124
+ private def ordinalFor (enumCase : Symbol ): Int =
125
+ enumCaseOrdinals.remove(enumCase).get
126
126
127
127
/** 1. If this is an enum class, add $name and $ordinal parameters to its
128
128
* parameter accessors and pass them on to the java.lang.Enum constructor.
@@ -145,17 +145,20 @@ class CompleteJavaEnums extends MiniPhase with InfoTransformer { thisPhase =>
145
145
* "same as before"
146
146
* }
147
147
*/
148
- override def transformTemplate (templ : Template )(using Context ): Template = {
148
+ override def transformTemplate (templ : Template )(using Context ): Tree = {
149
149
val cls = templ.symbol.owner
150
150
if cls.derivesFromJavaEnum then
151
- registerEnumClass(cls)
151
+ registerEnumClass(cls) // invariant: class is visited before cases: see tests/pos/enum-companion-first.scala
152
152
val (params, rest) = decomposeTemplateBody(templ.body)
153
153
val addedDefs = addedParams(cls, isLocal= true , ParamAccessor )
154
154
val addedSyms = addedDefs.map(_.symbol.entered)
155
155
val addedForwarders = addedEnumForwarders(cls)
156
156
cpy.Template (templ)(
157
157
parents = addEnumConstrArgs(defn.JavaEnumClass , templ.parents, addedSyms.map(ref)),
158
158
body = params ++ addedDefs ++ addedForwarders ++ rest)
159
+ else if cls.linkedClass.derivesFromJavaEnum then
160
+ enumCaseOrdinals.clear() // remove simple cases // invariant: companion is visited after cases
161
+ templ
159
162
else if isJavaEnumValueImpl(cls) then
160
163
def creatorParamRef (name : TermName ) =
161
164
ref(cls.owner.paramSymss.head.find(_.name == name).get)
0 commit comments