Skip to content

remove dollar ordinal from Enum #9539

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Aug 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala
Original file line number Diff line number Diff line change
Expand Up @@ -125,20 +125,24 @@ object DesugarEnums {
/** A creation method for a value of enum type `E`, which is defined as follows:
*
* private def $new(_$ordinal: Int, $name: String) = new E with scala.runtime.EnumValue {
* def $ordinal = $tag
* override def toString = $name
* def ordinal = _$ordinal // if `E` does not derive from jl.Enum
* override def toString = $name // if `E` does not derive from jl.Enum
* $values.register(this)
* }
*/
private def enumValueCreator(using Context) = {
val ordinalDef = ordinalMeth(Ident(nme.ordinalDollar_))
val toStringDef = toStringMeth(Ident(nme.nameDollar))
val fieldMethods =
if isJavaEnum then Nil
else
val ordinalDef = ordinalMeth(Ident(nme.ordinalDollar_))
val toStringDef = toStringMeth(Ident(nme.nameDollar))
List(ordinalDef, toStringDef)
val creator = New(Template(
constr = emptyConstructor,
parents = enumClassRef :: scalaRuntimeDot(tpnme.EnumValue) :: Nil,
derived = Nil,
self = EmptyValDef,
body = ordinalDef :: toStringDef :: registerCall :: Nil
body = fieldMethods ::: registerCall :: Nil
).withAttachment(ExtendsSingletonMirror, ()))
DefDef(nme.DOLLAR_NEW, Nil,
List(List(param(nme.ordinalDollar_, defn.IntType), param(nme.nameDollar, defn.StringType))),
Expand Down Expand Up @@ -264,8 +268,10 @@ object DesugarEnums {
def param(name: TermName, typ: Type)(using Context) =
ValDef(name, TypeTree(typ), EmptyTree).withFlags(Param)

private def isJavaEnum(using Context): Boolean = ctx.owner.linkedClass.derivesFrom(defn.JavaEnumClass)

def ordinalMeth(body: Tree)(using Context): DefDef =
DefDef(nme.ordinalDollar, Nil, Nil, TypeTree(defn.IntType), body)
DefDef(nme.ordinal, Nil, Nil, TypeTree(defn.IntType), body)

def toStringMeth(body: Tree)(using Context): DefDef =
DefDef(nme.toString_, Nil, Nil, TypeTree(defn.StringType), body).withFlags(Override)
Expand All @@ -284,12 +290,16 @@ object DesugarEnums {
expandSimpleEnumCase(name, mods, span)
else {
val (tag, scaffolding) = nextOrdinal(CaseKind.Object)
val ordinalDef = ordinalMethLit(tag)
val toStringDef = toStringMethLit(name.toString)
val fieldMethods =
if isJavaEnum then Nil
else
val ordinalDef = ordinalMethLit(tag)
val toStringDef = toStringMethLit(name.toString)
List(ordinalDef, toStringDef)
val impl1 = cpy.Template(impl)(
parents = impl.parents :+ scalaRuntimeDot(tpnme.EnumValue),
body = ordinalDef :: toStringDef :: registerCall :: Nil
).withAttachment(ExtendsSingletonMirror, ())
body = fieldMethods ::: registerCall :: Nil)
.withAttachment(ExtendsSingletonMirror, ())
val vdef = ValDef(name, TypeTree(), New(impl1)).withMods(mods.withAddedFlags(EnumValue, span))
flatTree(scaffolding ::: vdef :: Nil).withSpan(span)
}
Expand Down
70 changes: 47 additions & 23 deletions compiler/src/dotty/tools/dotc/transform/CompleteJavaEnums.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import DenotTransformers._
import dotty.tools.dotc.ast.Trees._
import SymUtils._

import annotation.threadUnsafe

object CompleteJavaEnums {
val name: String = "completeJavaEnums"

Expand Down Expand Up @@ -62,9 +64,10 @@ class CompleteJavaEnums extends MiniPhase with InfoTransformer { thisPhase =>
/** The list of parameter definitions `$name: String, $ordinal: Int`, in given `owner`
* with given flags (either `Param` or `ParamAccessor`)
*/
private def addedParams(owner: Symbol, flag: FlagSet)(using Context): List[ValDef] = {
val nameParam = newSymbol(owner, nameParamName, flag | Synthetic, defn.StringType, coord = owner.span)
val ordinalParam = newSymbol(owner, ordinalParamName, flag | Synthetic, defn.IntType, coord = owner.span)
private def addedParams(owner: Symbol, isLocal: Boolean, flag: FlagSet)(using Context): List[ValDef] = {
val flags = flag | Synthetic | (if isLocal then Private | Deferred else EmptyFlags)
val nameParam = newSymbol(owner, nameParamName, flags, defn.StringType, coord = owner.span)
val ordinalParam = newSymbol(owner, ordinalParamName, flags, defn.IntType, coord = owner.span)
List(ValDef(nameParam), ValDef(ordinalParam))
}

Expand All @@ -85,7 +88,7 @@ class CompleteJavaEnums extends MiniPhase with InfoTransformer { thisPhase =>
val sym = tree.symbol
if (sym.isConstructor && sym.owner.derivesFromJavaEnum)
val tree1 = cpy.DefDef(tree)(
vparamss = tree.vparamss.init :+ (tree.vparamss.last ++ addedParams(sym, Param)))
vparamss = tree.vparamss.init :+ (tree.vparamss.last ++ addedParams(sym, isLocal=false, Param)))
sym.setParamssFromDefs(tree1.tparams, tree1.vparamss)
tree1
else tree
Expand All @@ -107,47 +110,68 @@ class CompleteJavaEnums extends MiniPhase with InfoTransformer { thisPhase =>
}
}

private def isJavaEnumValueImpl(cls: Symbol)(using Context): Boolean =
cls.isAnonymousClass
&& (((cls.owner.name eq nme.DOLLAR_NEW) && cls.owner.isAllOf(Private|Synthetic)) || cls.owner.isAllOf(EnumCase))
&& cls.owner.owner.linkedClass.derivesFromJavaEnum

private val enumCaseOrdinals: MutableSymbolMap[Int] = newMutableSymbolMap

private def registerEnumClass(cls: Symbol)(using Context): Unit =
cls.children.zipWithIndex.foreach(enumCaseOrdinals.put)

private def ordinalFor(enumCase: Symbol): Int =
enumCaseOrdinals.remove(enumCase).get

/** 1. If this is an enum class, add $name and $ordinal parameters to its
* parameter accessors and pass them on to the java.lang.Enum constructor.
*
* 2. If this is an anonymous class that implement a value enum case,
* 2. If this is an anonymous class that implement a singleton enum case,
* pass $name and $ordinal parameters to the enum superclass. The class
* looks like this:
*
* class $anon extends E(...) {
* ...
* def ordinal = N
* def toString = S
* ...
* }
*
* After the transform it is expanded to
*
* class $anon extends E(..., N, S) {
* "same as before"
* class $anon extends E(..., $name, _$ordinal) { // if class implements a simple enum case
* "same as before"
* }
*
* class $anon extends E(..., "A", 0) { // if class implements a value enum case `A` with ordinal 0
* "same as before"
* }
*/
override def transformTemplate(templ: Template)(using Context): Template = {
override def transformTemplate(templ: Template)(using Context): Tree = {
val cls = templ.symbol.owner
if (cls.derivesFromJavaEnum) {
if cls.derivesFromJavaEnum then
registerEnumClass(cls) // invariant: class is visited before cases: see tests/pos/enum-companion-first.scala
val (params, rest) = decomposeTemplateBody(templ.body)
val addedDefs = addedParams(cls, ParamAccessor)
val addedDefs = addedParams(cls, isLocal=true, ParamAccessor)
val addedSyms = addedDefs.map(_.symbol.entered)
val addedForwarders = addedEnumForwarders(cls)
cpy.Template(templ)(
parents = addEnumConstrArgs(defn.JavaEnumClass, templ.parents, addedSyms.map(ref)),
body = params ++ addedDefs ++ addedForwarders ++ rest)
}
else if (cls.isAnonymousClass && ((cls.owner.name eq nme.DOLLAR_NEW) || cls.owner.isAllOf(EnumCase)) &&
cls.owner.owner.linkedClass.derivesFromJavaEnum) {
def rhsOf(name: TermName) =
templ.body.collect {
case mdef: DefDef if mdef.name == name => mdef.rhs
}.head
val args = List(rhsOf(nme.toString_), rhsOf(nme.ordinalDollar))
else if isJavaEnumValueImpl(cls) then
def creatorParamRef(name: TermName) =
ref(cls.owner.paramSymss.head.find(_.name == name).get)
val args =
if cls.owner.isAllOf(EnumCase) then
List(Literal(Constant(cls.owner.name.toString)), Literal(Constant(ordinalFor(cls.owner))))
else
List(creatorParamRef(nme.nameDollar), creatorParamRef(nme.ordinalDollar_))
cpy.Template(templ)(
parents = addEnumConstrArgs(cls.owner.owner.linkedClass, templ.parents, args))
}
parents = addEnumConstrArgs(cls.owner.owner.linkedClass, templ.parents, args),
)
else if cls.linkedClass.derivesFromJavaEnum then
enumCaseOrdinals.clear() // remove simple cases // invariant: companion is visited after cases
templ
else templ
}

override def checkPostCondition(tree: Tree)(using Context): Unit =
assert(enumCaseOrdinals.isEmpty, "Java based enum ordinal cache was not cleared")
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
private var myValueSymbols: List[Symbol] = Nil
private var myCaseSymbols: List[Symbol] = Nil
private var myCaseModuleSymbols: List[Symbol] = Nil
private var myEnumCaseSymbols: List[Symbol] = Nil

private def initSymbols(using Context) =
if (myValueSymbols.isEmpty) {
Expand All @@ -66,13 +65,11 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
defn.Product_productArity, defn.Product_productPrefix, defn.Product_productElement,
defn.Product_productElementName)
myCaseModuleSymbols = myCaseSymbols.filter(_ ne defn.Any_equals)
myEnumCaseSymbols = List(defn.Enum_ordinal)
}

def valueSymbols(using Context): List[Symbol] = { initSymbols; myValueSymbols }
def caseSymbols(using Context): List[Symbol] = { initSymbols; myCaseSymbols }
def caseModuleSymbols(using Context): List[Symbol] = { initSymbols; myCaseModuleSymbols }
def enumCaseSymbols(using Context): List[Symbol] = { initSymbols; myEnumCaseSymbols }

private def existingDef(sym: Symbol, clazz: ClassSymbol)(using Context): Symbol = {
val existing = sym.matchingMember(clazz.thisType)
Expand All @@ -96,9 +93,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
val symbolsToSynthesize: List[Symbol] =
if (clazz.is(Case))
if (clazz.is(Module)) caseModuleSymbols
else if (isEnumCase) caseSymbols ++ enumCaseSymbols
else caseSymbols
else if (isEnumCase) enumCaseSymbols
else if (isDerivedValueClass(clazz)) valueSymbols
else Nil

Expand Down Expand Up @@ -128,7 +123,6 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
case nme.productPrefix => ownName
case nme.productElement => productElementBody(accessors.length, vrefss.head.head)
case nme.productElementName => productElementNameBody(accessors.length, vrefss.head.head)
case nme.ordinal => Select(This(clazz), nme.ordinalDollar)
}
report.log(s"adding $synthetic to $clazz at ${ctx.phase}")
synthesizeDef(synthetic, syntheticRHS)
Expand Down
12 changes: 7 additions & 5 deletions docs/docs/reference/enums/desugarEnums.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ map into `case class`es or `val`s.
```
expands to a `sealed abstract` class that extends the `scala.Enum` trait and
an associated companion object that contains the defined cases, expanded according
to rules (2 - 8). The enum trait starts with a compiler-generated import that imports
the names `<caseIds>` of all cases so that they can be used without prefix in the trait.
to rules (2 - 8). The enum class starts with a compiler-generated import that imports
the names `<caseIds>` of all cases so that they can be used without prefix in the class.
```scala
sealed abstract class E ... extends <parents> with scala.Enum {
import E.{ <caseIds> }
Expand Down Expand Up @@ -174,13 +174,15 @@ If `E` contains at least one simple case, its companion object will define in ad
follows.
```scala
private def $new(_$ordinal: Int, $name: String) = new E with runtime.EnumValue {
def $ordinal = $_ordinal
override def toString = $name
def ordinal = _$ordinal // if `E` does not have `java.lang.Enum` as a parent
override def toString = $name // if `E` does not have `java.lang.Enum` as a parent
$values.register(this) // register enum value so that `valueOf` and `values` can return it.
}
```

The `$ordinal` method above is used to generate the `ordinal` method if the enum does not extend a `java.lang.Enum` (as Scala enums do not extend `java.lang.Enum`s unless explicitly specified). In case it does, there is no need to generate `ordinal` as `java.lang.Enum` defines it.
The anonymous class also implements the abstract `Product` methods that it inherits from `Enum`.
The `ordinal` method is only generated if the enum does not extend from `java.lang.Enum` (as Scala enums do not extend `java.lang.Enum`s unless explicitly specified). In case it does, there is no need to generate `ordinal` as `java.lang.Enum` defines it. Similarly there is no need to override `toString` as that is defined in terms of `name` in
`java.lang.Enum`.

### Scopes for Enum Cases

Expand Down
2 changes: 0 additions & 2 deletions library/src-bootstrapped/scala/Enum.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,3 @@ trait Enum extends Product, Serializable:

/** A number uniquely identifying a case of an enum */
def ordinal: Int
protected def $ordinal: Int

2 changes: 1 addition & 1 deletion library/src-non-bootstrapped/scala/Enum.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package scala

/** A base trait of all enum classes */
trait Enum:
trait Enum extends Product, Serializable:

/** A number uniquely identifying a case of an enum */
def ordinal: Int
Expand Down
4 changes: 2 additions & 2 deletions tests/pos/enum-List-control.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
abstract sealed class List[T] extends Enum
object List {
final class Cons[T](x: T, xs: List[T]) extends List[T] {
def $ordinal = 0
def ordinal = 0
def canEqual(that: Any): Boolean = that.isInstanceOf[Cons[_]]
def productArity: Int = 2
def productElement(n: Int): Any = n match
Expand All @@ -12,7 +12,7 @@ object List {
def apply[T](x: T, xs: List[T]): List[T] = new Cons(x, xs)
}
final class Nil[T]() extends List[T], runtime.EnumValue {
def $ordinal = 1
def ordinal = 1
}
object Nil {
def apply[T](): List[T] = new Nil()
Expand Down
9 changes: 9 additions & 0 deletions tests/pos/enum-companion-first.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
object Planet:
final val G = 6.67300E-11

enum Planet(mass: Double, radius: Double) extends java.lang.Enum[Planet]:
def surfaceGravity = Planet.G * mass / (radius * radius)
def surfaceWeight(otherMass: Double) = otherMass * surfaceGravity

case Mercury extends Planet(3.303e+23, 2.4397e6)
case Venus extends Planet(4.869e+24, 6.0518e6)
5 changes: 5 additions & 0 deletions tests/run/enum-ordinal-java/Lib.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
object Lib1:
trait MyJavaEnum[E <: java.lang.Enum[E]] extends java.lang.Enum[E]

object Lib2:
type JavaEnumAlias[E <: java.lang.Enum[E]] = java.lang.Enum[E]
9 changes: 9 additions & 0 deletions tests/run/enum-ordinal-java/Test.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
enum Color1 extends Lib1.MyJavaEnum[Color1]:
case Red, Green, Blue

enum Color2 extends Lib2.JavaEnumAlias[Color2]:
case Red, Green, Blue

@main def Test =
assert(Color1.Green.ordinal == 1)
assert(Color2.Blue.ordinal == 2)
78 changes: 75 additions & 3 deletions tests/run/enum-values-order.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,80 @@
/** immutable hashmaps (as of 2.13 collections) only store up to 4 entries in insertion order */
enum LatinAlphabet { case A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z }

enum LatinAlphabet2 extends java.lang.Enum[LatinAlphabet2] { case A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z }

enum LatinAlphabet3[+T] extends java.lang.Enum[LatinAlphabet3[_]] { case A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z }

object Color:
trait Pretty
enum Color extends java.lang.Enum[Color]:
case Red, Green, Blue
case Aqua extends Color with Color.Pretty
case Grey, Black, White
case Emerald extends Color with Color.Pretty
case Brown

@main def Test =
import LatinAlphabet._
val ordered = Seq(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z)

assert(ordered sameElements LatinAlphabet.values)

def testLatin() =

val ordinals = Seq(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25)
val labels = Seq("A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z")

def testLatin1() =
import LatinAlphabet._
val ordered = Seq(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z)

assert(ordered sameElements LatinAlphabet.values)
assert(ordinals == ordered.map(_.ordinal))
assert(labels == ordered.map(_.productPrefix))

def testLatin2() =
import LatinAlphabet2._
val ordered = Seq(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z)

assert(ordered sameElements LatinAlphabet2.values)
assert(ordinals == ordered.map(_.ordinal))
assert(labels == ordered.map(_.name))

def testLatin3() =
import LatinAlphabet3._
val ordered = Seq(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z)

assert(ordered sameElements LatinAlphabet3.values)
assert(ordinals == ordered.map(_.ordinal))
assert(labels == ordered.map(_.name))

testLatin1()
testLatin2()
testLatin3()

end testLatin

def testColor() =
import Color._
val ordered = Seq(Red, Green, Blue, Aqua, Grey, Black, White, Emerald, Brown)
val ordinals = Seq(0, 1, 2, 3, 4, 5, 6, 7, 8)
val labels = Seq("Red", "Green", "Blue", "Aqua", "Grey", "Black", "White", "Emerald", "Brown")

assert(ordered sameElements Color.values)
assert(ordinals == ordered.map(_.ordinal))
assert(labels == ordered.map(_.name))

def isPretty(c: Color): Boolean = c match
case _: Pretty => true
case _ => false

assert(!isPretty(Brown))
assert(!isPretty(Grey))
assert(isPretty(Aqua))
assert(isPretty(Emerald))
assert(Emerald.getClass != Aqua.getClass)
assert(Aqua.getClass != Grey.getClass)
assert(Grey.getClass == Brown.getClass)

end testColor

testLatin()
testColor()
Loading