diff --git a/compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala b/compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala index d0bffd488c8e..d029b686e138 100644 --- a/compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala +++ b/compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala @@ -125,20 +125,24 @@ object DesugarEnums { /** A creation method for a value of enum type `E`, which is defined as follows: * * private def $new(_$ordinal: Int, $name: String) = new E with scala.runtime.EnumValue { - * def $ordinal = $tag - * override def toString = $name + * def ordinal = _$ordinal // if `E` does not derive from jl.Enum + * override def toString = $name // if `E` does not derive from jl.Enum * $values.register(this) * } */ private def enumValueCreator(using Context) = { - val ordinalDef = ordinalMeth(Ident(nme.ordinalDollar_)) - val toStringDef = toStringMeth(Ident(nme.nameDollar)) + val fieldMethods = + if isJavaEnum then Nil + else + val ordinalDef = ordinalMeth(Ident(nme.ordinalDollar_)) + val toStringDef = toStringMeth(Ident(nme.nameDollar)) + List(ordinalDef, toStringDef) val creator = New(Template( constr = emptyConstructor, parents = enumClassRef :: scalaRuntimeDot(tpnme.EnumValue) :: Nil, derived = Nil, self = EmptyValDef, - body = ordinalDef :: toStringDef :: registerCall :: Nil + body = fieldMethods ::: registerCall :: Nil ).withAttachment(ExtendsSingletonMirror, ())) DefDef(nme.DOLLAR_NEW, Nil, List(List(param(nme.ordinalDollar_, defn.IntType), param(nme.nameDollar, defn.StringType))), @@ -264,8 +268,10 @@ object DesugarEnums { def param(name: TermName, typ: Type)(using Context) = ValDef(name, TypeTree(typ), EmptyTree).withFlags(Param) + private def isJavaEnum(using Context): Boolean = ctx.owner.linkedClass.derivesFrom(defn.JavaEnumClass) + def ordinalMeth(body: Tree)(using Context): DefDef = - DefDef(nme.ordinalDollar, Nil, Nil, TypeTree(defn.IntType), body) + DefDef(nme.ordinal, Nil, Nil, TypeTree(defn.IntType), body) def toStringMeth(body: Tree)(using Context): DefDef = DefDef(nme.toString_, Nil, Nil, TypeTree(defn.StringType), body).withFlags(Override) @@ -284,12 +290,16 @@ object DesugarEnums { expandSimpleEnumCase(name, mods, span) else { val (tag, scaffolding) = nextOrdinal(CaseKind.Object) - val ordinalDef = ordinalMethLit(tag) - val toStringDef = toStringMethLit(name.toString) + val fieldMethods = + if isJavaEnum then Nil + else + val ordinalDef = ordinalMethLit(tag) + val toStringDef = toStringMethLit(name.toString) + List(ordinalDef, toStringDef) val impl1 = cpy.Template(impl)( parents = impl.parents :+ scalaRuntimeDot(tpnme.EnumValue), - body = ordinalDef :: toStringDef :: registerCall :: Nil - ).withAttachment(ExtendsSingletonMirror, ()) + body = fieldMethods ::: registerCall :: Nil) + .withAttachment(ExtendsSingletonMirror, ()) val vdef = ValDef(name, TypeTree(), New(impl1)).withMods(mods.withAddedFlags(EnumValue, span)) flatTree(scaffolding ::: vdef :: Nil).withSpan(span) } diff --git a/compiler/src/dotty/tools/dotc/transform/CompleteJavaEnums.scala b/compiler/src/dotty/tools/dotc/transform/CompleteJavaEnums.scala index f3b0a6063574..3c0795fb3093 100644 --- a/compiler/src/dotty/tools/dotc/transform/CompleteJavaEnums.scala +++ b/compiler/src/dotty/tools/dotc/transform/CompleteJavaEnums.scala @@ -15,6 +15,8 @@ import DenotTransformers._ import dotty.tools.dotc.ast.Trees._ import SymUtils._ +import annotation.threadUnsafe + object CompleteJavaEnums { val name: String = "completeJavaEnums" @@ -62,9 +64,10 @@ class CompleteJavaEnums extends MiniPhase with InfoTransformer { thisPhase => /** The list of parameter definitions `$name: String, $ordinal: Int`, in given `owner` * with given flags (either `Param` or `ParamAccessor`) */ - private def addedParams(owner: Symbol, flag: FlagSet)(using Context): List[ValDef] = { - val nameParam = newSymbol(owner, nameParamName, flag | Synthetic, defn.StringType, coord = owner.span) - val ordinalParam = newSymbol(owner, ordinalParamName, flag | Synthetic, defn.IntType, coord = owner.span) + private def addedParams(owner: Symbol, isLocal: Boolean, flag: FlagSet)(using Context): List[ValDef] = { + val flags = flag | Synthetic | (if isLocal then Private | Deferred else EmptyFlags) + val nameParam = newSymbol(owner, nameParamName, flags, defn.StringType, coord = owner.span) + val ordinalParam = newSymbol(owner, ordinalParamName, flags, defn.IntType, coord = owner.span) List(ValDef(nameParam), ValDef(ordinalParam)) } @@ -85,7 +88,7 @@ class CompleteJavaEnums extends MiniPhase with InfoTransformer { thisPhase => val sym = tree.symbol if (sym.isConstructor && sym.owner.derivesFromJavaEnum) val tree1 = cpy.DefDef(tree)( - vparamss = tree.vparamss.init :+ (tree.vparamss.last ++ addedParams(sym, Param))) + vparamss = tree.vparamss.init :+ (tree.vparamss.last ++ addedParams(sym, isLocal=false, Param))) sym.setParamssFromDefs(tree1.tparams, tree1.vparamss) tree1 else tree @@ -107,47 +110,68 @@ class CompleteJavaEnums extends MiniPhase with InfoTransformer { thisPhase => } } + private def isJavaEnumValueImpl(cls: Symbol)(using Context): Boolean = + cls.isAnonymousClass + && (((cls.owner.name eq nme.DOLLAR_NEW) && cls.owner.isAllOf(Private|Synthetic)) || cls.owner.isAllOf(EnumCase)) + && cls.owner.owner.linkedClass.derivesFromJavaEnum + + private val enumCaseOrdinals: MutableSymbolMap[Int] = newMutableSymbolMap + + private def registerEnumClass(cls: Symbol)(using Context): Unit = + cls.children.zipWithIndex.foreach(enumCaseOrdinals.put) + + private def ordinalFor(enumCase: Symbol): Int = + enumCaseOrdinals.remove(enumCase).get + /** 1. If this is an enum class, add $name and $ordinal parameters to its * parameter accessors and pass them on to the java.lang.Enum constructor. * - * 2. If this is an anonymous class that implement a value enum case, + * 2. If this is an anonymous class that implement a singleton enum case, * pass $name and $ordinal parameters to the enum superclass. The class * looks like this: * * class $anon extends E(...) { * ... - * def ordinal = N - * def toString = S - * ... * } * * After the transform it is expanded to * - * class $anon extends E(..., N, S) { - * "same as before" + * class $anon extends E(..., $name, _$ordinal) { // if class implements a simple enum case + * "same as before" + * } + * + * class $anon extends E(..., "A", 0) { // if class implements a value enum case `A` with ordinal 0 + * "same as before" * } */ - override def transformTemplate(templ: Template)(using Context): Template = { + override def transformTemplate(templ: Template)(using Context): Tree = { val cls = templ.symbol.owner - if (cls.derivesFromJavaEnum) { + if cls.derivesFromJavaEnum then + registerEnumClass(cls) // invariant: class is visited before cases: see tests/pos/enum-companion-first.scala val (params, rest) = decomposeTemplateBody(templ.body) - val addedDefs = addedParams(cls, ParamAccessor) + val addedDefs = addedParams(cls, isLocal=true, ParamAccessor) val addedSyms = addedDefs.map(_.symbol.entered) val addedForwarders = addedEnumForwarders(cls) cpy.Template(templ)( parents = addEnumConstrArgs(defn.JavaEnumClass, templ.parents, addedSyms.map(ref)), body = params ++ addedDefs ++ addedForwarders ++ rest) - } - else if (cls.isAnonymousClass && ((cls.owner.name eq nme.DOLLAR_NEW) || cls.owner.isAllOf(EnumCase)) && - cls.owner.owner.linkedClass.derivesFromJavaEnum) { - def rhsOf(name: TermName) = - templ.body.collect { - case mdef: DefDef if mdef.name == name => mdef.rhs - }.head - val args = List(rhsOf(nme.toString_), rhsOf(nme.ordinalDollar)) + else if isJavaEnumValueImpl(cls) then + def creatorParamRef(name: TermName) = + ref(cls.owner.paramSymss.head.find(_.name == name).get) + val args = + if cls.owner.isAllOf(EnumCase) then + List(Literal(Constant(cls.owner.name.toString)), Literal(Constant(ordinalFor(cls.owner)))) + else + List(creatorParamRef(nme.nameDollar), creatorParamRef(nme.ordinalDollar_)) cpy.Template(templ)( - parents = addEnumConstrArgs(cls.owner.owner.linkedClass, templ.parents, args)) - } + parents = addEnumConstrArgs(cls.owner.owner.linkedClass, templ.parents, args), + ) + else if cls.linkedClass.derivesFromJavaEnum then + enumCaseOrdinals.clear() // remove simple cases // invariant: companion is visited after cases + templ else templ } + + override def checkPostCondition(tree: Tree)(using Context): Unit = + assert(enumCaseOrdinals.isEmpty, "Java based enum ordinal cache was not cleared") } diff --git a/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala b/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala index 6c83143a1f7a..37f0128d5520 100644 --- a/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala +++ b/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala @@ -57,7 +57,6 @@ class SyntheticMembers(thisPhase: DenotTransformer) { private var myValueSymbols: List[Symbol] = Nil private var myCaseSymbols: List[Symbol] = Nil private var myCaseModuleSymbols: List[Symbol] = Nil - private var myEnumCaseSymbols: List[Symbol] = Nil private def initSymbols(using Context) = if (myValueSymbols.isEmpty) { @@ -66,13 +65,11 @@ class SyntheticMembers(thisPhase: DenotTransformer) { defn.Product_productArity, defn.Product_productPrefix, defn.Product_productElement, defn.Product_productElementName) myCaseModuleSymbols = myCaseSymbols.filter(_ ne defn.Any_equals) - myEnumCaseSymbols = List(defn.Enum_ordinal) } def valueSymbols(using Context): List[Symbol] = { initSymbols; myValueSymbols } def caseSymbols(using Context): List[Symbol] = { initSymbols; myCaseSymbols } def caseModuleSymbols(using Context): List[Symbol] = { initSymbols; myCaseModuleSymbols } - def enumCaseSymbols(using Context): List[Symbol] = { initSymbols; myEnumCaseSymbols } private def existingDef(sym: Symbol, clazz: ClassSymbol)(using Context): Symbol = { val existing = sym.matchingMember(clazz.thisType) @@ -96,9 +93,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) { val symbolsToSynthesize: List[Symbol] = if (clazz.is(Case)) if (clazz.is(Module)) caseModuleSymbols - else if (isEnumCase) caseSymbols ++ enumCaseSymbols else caseSymbols - else if (isEnumCase) enumCaseSymbols else if (isDerivedValueClass(clazz)) valueSymbols else Nil @@ -128,7 +123,6 @@ class SyntheticMembers(thisPhase: DenotTransformer) { case nme.productPrefix => ownName case nme.productElement => productElementBody(accessors.length, vrefss.head.head) case nme.productElementName => productElementNameBody(accessors.length, vrefss.head.head) - case nme.ordinal => Select(This(clazz), nme.ordinalDollar) } report.log(s"adding $synthetic to $clazz at ${ctx.phase}") synthesizeDef(synthetic, syntheticRHS) diff --git a/docs/docs/reference/enums/desugarEnums.md b/docs/docs/reference/enums/desugarEnums.md index c20d1cf42d02..b9bf2995ffed 100644 --- a/docs/docs/reference/enums/desugarEnums.md +++ b/docs/docs/reference/enums/desugarEnums.md @@ -36,8 +36,8 @@ map into `case class`es or `val`s. ``` expands to a `sealed abstract` class that extends the `scala.Enum` trait and an associated companion object that contains the defined cases, expanded according - to rules (2 - 8). The enum trait starts with a compiler-generated import that imports - the names `` of all cases so that they can be used without prefix in the trait. + to rules (2 - 8). The enum class starts with a compiler-generated import that imports + the names `` of all cases so that they can be used without prefix in the class. ```scala sealed abstract class E ... extends with scala.Enum { import E.{ } @@ -174,13 +174,15 @@ If `E` contains at least one simple case, its companion object will define in ad follows. ```scala private def $new(_$ordinal: Int, $name: String) = new E with runtime.EnumValue { - def $ordinal = $_ordinal - override def toString = $name + def ordinal = _$ordinal // if `E` does not have `java.lang.Enum` as a parent + override def toString = $name // if `E` does not have `java.lang.Enum` as a parent $values.register(this) // register enum value so that `valueOf` and `values` can return it. } ``` -The `$ordinal` method above is used to generate the `ordinal` method if the enum does not extend a `java.lang.Enum` (as Scala enums do not extend `java.lang.Enum`s unless explicitly specified). In case it does, there is no need to generate `ordinal` as `java.lang.Enum` defines it. +The anonymous class also implements the abstract `Product` methods that it inherits from `Enum`. +The `ordinal` method is only generated if the enum does not extend from `java.lang.Enum` (as Scala enums do not extend `java.lang.Enum`s unless explicitly specified). In case it does, there is no need to generate `ordinal` as `java.lang.Enum` defines it. Similarly there is no need to override `toString` as that is defined in terms of `name` in +`java.lang.Enum`. ### Scopes for Enum Cases diff --git a/library/src-bootstrapped/scala/Enum.scala b/library/src-bootstrapped/scala/Enum.scala index d1e72cb06ff1..ce21eb12cd08 100644 --- a/library/src-bootstrapped/scala/Enum.scala +++ b/library/src-bootstrapped/scala/Enum.scala @@ -5,5 +5,3 @@ trait Enum extends Product, Serializable: /** A number uniquely identifying a case of an enum */ def ordinal: Int - protected def $ordinal: Int - diff --git a/library/src-non-bootstrapped/scala/Enum.scala b/library/src-non-bootstrapped/scala/Enum.scala index 4f8fe897d41c..6b6c2f499ff1 100644 --- a/library/src-non-bootstrapped/scala/Enum.scala +++ b/library/src-non-bootstrapped/scala/Enum.scala @@ -1,7 +1,7 @@ package scala /** A base trait of all enum classes */ -trait Enum: +trait Enum extends Product, Serializable: /** A number uniquely identifying a case of an enum */ def ordinal: Int diff --git a/tests/pos/enum-List-control.scala b/tests/pos/enum-List-control.scala index 2a957a2c4ab2..b52c9b41ec87 100644 --- a/tests/pos/enum-List-control.scala +++ b/tests/pos/enum-List-control.scala @@ -1,7 +1,7 @@ abstract sealed class List[T] extends Enum object List { final class Cons[T](x: T, xs: List[T]) extends List[T] { - def $ordinal = 0 + def ordinal = 0 def canEqual(that: Any): Boolean = that.isInstanceOf[Cons[_]] def productArity: Int = 2 def productElement(n: Int): Any = n match @@ -12,7 +12,7 @@ object List { def apply[T](x: T, xs: List[T]): List[T] = new Cons(x, xs) } final class Nil[T]() extends List[T], runtime.EnumValue { - def $ordinal = 1 + def ordinal = 1 } object Nil { def apply[T](): List[T] = new Nil() diff --git a/tests/pos/enum-companion-first.scala b/tests/pos/enum-companion-first.scala new file mode 100644 index 000000000000..c61051efa0fe --- /dev/null +++ b/tests/pos/enum-companion-first.scala @@ -0,0 +1,9 @@ +object Planet: + final val G = 6.67300E-11 + +enum Planet(mass: Double, radius: Double) extends java.lang.Enum[Planet]: + def surfaceGravity = Planet.G * mass / (radius * radius) + def surfaceWeight(otherMass: Double) = otherMass * surfaceGravity + + case Mercury extends Planet(3.303e+23, 2.4397e6) + case Venus extends Planet(4.869e+24, 6.0518e6) diff --git a/tests/run/enum-ordinal-java/Lib.scala b/tests/run/enum-ordinal-java/Lib.scala new file mode 100644 index 000000000000..75b9003e5553 --- /dev/null +++ b/tests/run/enum-ordinal-java/Lib.scala @@ -0,0 +1,5 @@ +object Lib1: + trait MyJavaEnum[E <: java.lang.Enum[E]] extends java.lang.Enum[E] + +object Lib2: + type JavaEnumAlias[E <: java.lang.Enum[E]] = java.lang.Enum[E] diff --git a/tests/run/enum-ordinal-java/Test.scala b/tests/run/enum-ordinal-java/Test.scala new file mode 100644 index 000000000000..082ea85f7044 --- /dev/null +++ b/tests/run/enum-ordinal-java/Test.scala @@ -0,0 +1,9 @@ +enum Color1 extends Lib1.MyJavaEnum[Color1]: + case Red, Green, Blue + +enum Color2 extends Lib2.JavaEnumAlias[Color2]: + case Red, Green, Blue + +@main def Test = + assert(Color1.Green.ordinal == 1) + assert(Color2.Blue.ordinal == 2) diff --git a/tests/run/enum-values-order.scala b/tests/run/enum-values-order.scala index be2b602f158c..400cdbc8aac5 100644 --- a/tests/run/enum-values-order.scala +++ b/tests/run/enum-values-order.scala @@ -1,8 +1,80 @@ /** immutable hashmaps (as of 2.13 collections) only store up to 4 entries in insertion order */ enum LatinAlphabet { case A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z } +enum LatinAlphabet2 extends java.lang.Enum[LatinAlphabet2] { case A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z } + +enum LatinAlphabet3[+T] extends java.lang.Enum[LatinAlphabet3[_]] { case A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z } + +object Color: + trait Pretty +enum Color extends java.lang.Enum[Color]: + case Red, Green, Blue + case Aqua extends Color with Color.Pretty + case Grey, Black, White + case Emerald extends Color with Color.Pretty + case Brown + @main def Test = - import LatinAlphabet._ - val ordered = Seq(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z) - assert(ordered sameElements LatinAlphabet.values) + + def testLatin() = + + val ordinals = Seq(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25) + val labels = Seq("A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z") + + def testLatin1() = + import LatinAlphabet._ + val ordered = Seq(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z) + + assert(ordered sameElements LatinAlphabet.values) + assert(ordinals == ordered.map(_.ordinal)) + assert(labels == ordered.map(_.productPrefix)) + + def testLatin2() = + import LatinAlphabet2._ + val ordered = Seq(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z) + + assert(ordered sameElements LatinAlphabet2.values) + assert(ordinals == ordered.map(_.ordinal)) + assert(labels == ordered.map(_.name)) + + def testLatin3() = + import LatinAlphabet3._ + val ordered = Seq(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z) + + assert(ordered sameElements LatinAlphabet3.values) + assert(ordinals == ordered.map(_.ordinal)) + assert(labels == ordered.map(_.name)) + + testLatin1() + testLatin2() + testLatin3() + + end testLatin + + def testColor() = + import Color._ + val ordered = Seq(Red, Green, Blue, Aqua, Grey, Black, White, Emerald, Brown) + val ordinals = Seq(0, 1, 2, 3, 4, 5, 6, 7, 8) + val labels = Seq("Red", "Green", "Blue", "Aqua", "Grey", "Black", "White", "Emerald", "Brown") + + assert(ordered sameElements Color.values) + assert(ordinals == ordered.map(_.ordinal)) + assert(labels == ordered.map(_.name)) + + def isPretty(c: Color): Boolean = c match + case _: Pretty => true + case _ => false + + assert(!isPretty(Brown)) + assert(!isPretty(Grey)) + assert(isPretty(Aqua)) + assert(isPretty(Emerald)) + assert(Emerald.getClass != Aqua.getClass) + assert(Aqua.getClass != Grey.getClass) + assert(Grey.getClass == Brown.getClass) + + end testColor + + testLatin() + testColor() diff --git a/tests/semanticdb/metac.expect b/tests/semanticdb/metac.expect index f0b2fb709613..89d15e15ed8d 100644 --- a/tests/semanticdb/metac.expect +++ b/tests/semanticdb/metac.expect @@ -690,7 +690,6 @@ _empty_/Enums.Maybe#``(). => primary ctor _empty_/Enums.Maybe. => final object Maybe _empty_/Enums.Maybe.$values. => val method $values _empty_/Enums.Maybe.Just# => final case enum class Just -_empty_/Enums.Maybe.Just#$ordinal(). => method $ordinal _empty_/Enums.Maybe.Just#[A] => typeparam A _empty_/Enums.Maybe.Just#_1(). => method _1 _empty_/Enums.Maybe.Just#``(). => primary ctor @@ -700,6 +699,7 @@ _empty_/Enums.Maybe.Just#copy$default$1().[A] => typeparam A _empty_/Enums.Maybe.Just#copy(). => method copy _empty_/Enums.Maybe.Just#copy().(value) => param value _empty_/Enums.Maybe.Just#copy().[A] => typeparam A +_empty_/Enums.Maybe.Just#ordinal(). => method ordinal _empty_/Enums.Maybe.Just#value. => val method value _empty_/Enums.Maybe.Just. => final object Just _empty_/Enums.Maybe.Just.apply(). => method apply @@ -787,11 +787,11 @@ _empty_/Enums.`<:<`#[B] => typeparam B _empty_/Enums.`<:<`#``(). => primary ctor _empty_/Enums.`<:<`. => final object <:< _empty_/Enums.`<:<`.Refl# => final case enum class Refl -_empty_/Enums.`<:<`.Refl#$ordinal(). => method $ordinal _empty_/Enums.`<:<`.Refl#[C] => typeparam C _empty_/Enums.`<:<`.Refl#``(). => primary ctor _empty_/Enums.`<:<`.Refl#copy(). => method copy _empty_/Enums.`<:<`.Refl#copy().[C] => typeparam C +_empty_/Enums.`<:<`.Refl#ordinal(). => method ordinal _empty_/Enums.`<:<`.Refl. => final object Refl _empty_/Enums.`<:<`.Refl.apply(). => method apply _empty_/Enums.`<:<`.Refl.apply().[C] => typeparam C