Skip to content

Commit b2570dc

Browse files
bishaboshaodersky
andcommitted
move implementation of ordinal in enums to posttyper
also add a check that ordinal is not implemented by the user, or mixed in by a trait - this is necessary as scala.deriving.Mirror.Sum delegates to ordinal method on an enum. Note that before this commit, as enum cases would previously declare ordinal methods at desugaring (without an override flag), refchecks would issue an override-without-override-modifier error anyway. Co-authored-by: Jamie Thompson <bishbashboshjt@gmail.com> Co-authored-by: Martin Odersky <odersky@gmail.com>
1 parent 16f9b22 commit b2570dc

File tree

7 files changed

+55
-13
lines changed

7 files changed

+55
-13
lines changed

compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -179,13 +179,12 @@ object DesugarEnums {
179179
* }
180180
*/
181181
private def enumValueCreator(using Context) = {
182-
val fieldMethods = if isJavaEnum then Nil else ordinalMeth(Ident(nme.ordinalDollar_)) :: Nil
183182
val creator = New(Template(
184183
constr = emptyConstructor,
185184
parents = enumClassRef :: scalaRuntimeDot(tpnme.EnumValue) :: Nil,
186185
derived = Nil,
187186
self = EmptyValDef,
188-
body = fieldMethods
187+
body = Nil
189188
).withAttachment(ExtendsSingletonMirror, ()))
190189
DefDef(nme.DOLLAR_NEW,
191190
List(List(param(nme.ordinalDollar_, defn.IntType), param(nme.nameDollar, defn.StringType))),
@@ -270,8 +269,6 @@ object DesugarEnums {
270269
def param(name: TermName, typ: Type)(using Context): ValDef = param(name, TypeTree(typ))
271270
def param(name: TermName, tpt: Tree)(using Context): ValDef = ValDef(name, tpt, EmptyTree).withFlags(Param)
272271

273-
private def isJavaEnum(using Context): Boolean = enumClass.derivesFrom(defn.JavaEnumClass)
274-
275272
def ordinalMeth(body: Tree)(using Context): DefDef =
276273
DefDef(nme.ordinal, Nil, TypeTree(defn.IntType), body).withAddedFlags(Synthetic)
277274

@@ -290,10 +287,8 @@ object DesugarEnums {
290287
expandSimpleEnumCase(name, mods, definesLookups, span)
291288
else {
292289
val (tag, scaffolding) = nextOrdinal(name, CaseKind.Object, definesLookups)
293-
val impl1 = cpy.Template(impl)(
294-
parents = impl.parents :+ scalaRuntimeDot(tpnme.EnumValue),
295-
body = if isJavaEnum then Nil else ordinalMethLit(tag) :: Nil
296-
).withAttachment(ExtendsSingletonMirror, ())
290+
val impl1 = cpy.Template(impl)(parents = impl.parents :+ scalaRuntimeDot(tpnme.EnumValue), body = Nil)
291+
.withAttachment(ExtendsSingletonMirror, ())
297292
val vdef = ValDef(name, TypeTree(), New(impl1)).withMods(mods.withAddedFlags(EnumValue, span))
298293
flatTree(vdef :: scaffolding).withSpan(span)
299294
}

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,7 @@ class Definitions {
733733
@tu lazy val NoneModule: Symbol = requiredModule("scala.None")
734734

735735
@tu lazy val EnumClass: ClassSymbol = requiredClass("scala.reflect.Enum")
736+
@tu lazy val Enum_ordinal: Symbol = EnumClass.requiredMethod(nme.ordinal)
736737

737738
@tu lazy val EnumValueSerializationProxyClass: ClassSymbol = requiredClass("scala.runtime.EnumValueSerializationProxy")
738739
@tu lazy val EnumValueSerializationProxyConstructor: TermSymbol =

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
6666
myCaseSymbols = defn.caseClassSynthesized
6767
myCaseModuleSymbols = myCaseSymbols.filter(_ ne defn.Any_equals)
6868
myEnumValueSymbols = List(defn.Product_productPrefix)
69-
myNonJavaEnumValueSymbols = myEnumValueSymbols :+ defn.Any_toString
69+
myNonJavaEnumValueSymbols = myEnumValueSymbols :+ defn.Any_toString :+ defn.Enum_ordinal
7070
}
7171

7272
def valueSymbols(using Context): List[Symbol] = { initSymbols; myValueSymbols }
@@ -132,6 +132,17 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
132132
else // assume owner is `val Foo = new MyEnum { def ordinal = 0 }`
133133
Literal(Constant(clazz.owner.name.toString))
134134

135+
def ordinalRef: Tree =
136+
if isSimpleEnumValue then // owner is `def $new(_$ordinal: Int, $name: String) = new MyEnum { ... }`
137+
ref(clazz.owner.paramSymss.head.find(_.name == nme.ordinalDollar_).get)
138+
else // val CaseN = new MyEnum { ... def ordinal: Int = n }
139+
val vdef = clazz.owner
140+
val parentEnum = vdef.owner.companionClass
141+
val children = parentEnum.children.zipWithIndex
142+
val candidate: Option[Int] = children.collectFirst { case (child, idx) if child == vdef => idx }
143+
assert(candidate.isDefined, i"could not find child for $vdef")
144+
Literal(Constant(candidate.get))
145+
135146
def toStringBody(vrefss: List[List[Tree]]): Tree =
136147
if (clazz.is(ModuleClass)) ownName
137148
else if (isNonJavaEnumValue) identifierRef
@@ -143,6 +154,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
143154
case nme.toString_ => toStringBody(vrefss)
144155
case nme.equals_ => equalsBody(vrefss.head.head)
145156
case nme.canEqual_ => canEqualBody(vrefss.head.head)
157+
case nme.ordinal => ordinalRef
146158
case nme.productArity => Literal(Constant(accessors.length))
147159
case nme.productPrefix if isEnumValue => nameRef
148160
case nme.productPrefix => ownName

compiler/src/dotty/tools/dotc/typer/Checking.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,8 +1217,19 @@ trait Checking {
12171217
/** 1. Check that all case classes that extend `scala.reflect.Enum` are `enum` cases
12181218
* 2. Check that parameterised `enum` cases do not extend java.lang.Enum.
12191219
* 3. Check that only a static `enum` base class can extend java.lang.Enum.
1220+
* 4. Check that user does not implement an `ordinal` method in the body of an enum class.
12201221
*/
12211222
def checkEnum(cdef: untpd.TypeDef, cls: Symbol, firstParent: Symbol)(using Context): Unit = {
1223+
def existingDef(sym: Symbol, clazz: ClassSymbol)(using Context): Symbol = // adapted from SyntheticMembers
1224+
val existing = sym.matchingMember(clazz.thisType)
1225+
if existing != sym && !existing.is(Deferred) then existing else NoSymbol
1226+
def checkExistingOrdinal(using Context) =
1227+
val decl = existingDef(defn.Enum_ordinal, cls.asClass)
1228+
if decl.exists then
1229+
if decl.owner == cls then
1230+
report.error(em"the ordinal method of enum $cls can not be defined by the user", decl.srcPos)
1231+
else
1232+
report.error(em"enum $cls can not inherit the concrete ordinal method of ${decl.owner}", cdef.srcPos)
12221233
def isEnumAnonCls =
12231234
cls.isAnonymousClass
12241235
&& cls.owner.isTerm
@@ -1238,6 +1249,8 @@ trait Checking {
12381249
// this test allows inheriting from `Enum` by hand;
12391250
// see enum-List-control.scala.
12401251
report.error(ClassCannotExtendEnum(cls, firstParent), cdef.srcPos)
1252+
if cls.isEnumClass && !isJavaEnum then
1253+
checkExistingOrdinal
12411254
}
12421255

12431256
/** Check that the firstParent for an enum case derives from the declaring enum class, if not, adds it as a parent

tests/neg/enumsLabel-singleimpl.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
enum Ordinalled {
22

3-
case A // error: method ordinal of type => Int needs `override` modifier
3+
case A
44

5-
def ordinal: Int = -1
5+
def ordinal: Int = -1 // error: the ordinal method of enum class Ordinalled can not be defined by the user
66

77
}
88

99
trait HasOrdinal { def ordinal: Int = 23 }
1010

11-
enum MyEnum extends HasOrdinal {
12-
case Foo // error: method ordinal of type => Int needs `override` modifier
11+
enum MyEnum extends HasOrdinal { // error: enum class MyEnum can not inherit the concrete ordinal method of trait HasOrdinal
12+
case Foo
1313
}

tests/pos/i13554.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
object StatusCode:
2+
class Matcher
3+
4+
enum StatusCode(m: StatusCode.Matcher):
5+
case InternalServerError extends StatusCode(???)
6+

tests/pos/i13554a.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
object StatusCode:
2+
enum Matcher:
3+
case ServerError extends Matcher
4+
end Matcher
5+
end StatusCode
6+
7+
enum StatusCode(code: Int, m: StatusCode.Matcher):
8+
case InternalServerError extends StatusCode(500, StatusCode.Matcher.ServerError)
9+
end StatusCode
10+
11+
object Main {
12+
def main(args: Array[String]): Unit = {
13+
println(StatusCode.InternalServerError)
14+
}
15+
}

0 commit comments

Comments
 (0)